MLIR  22.0.0git
LinalgTransformOps.cpp
Go to the documentation of this file.
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
41 #include "mlir/Support/LLVM.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/SmallPtrSet.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/DebugLog.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include <type_traits>
50 
51 using namespace mlir;
52 using namespace mlir::linalg;
53 using namespace mlir::transform;
54 
55 #define DEBUG_TYPE "linalg-transforms"
56 
57 /// Attempts to apply the pattern specified as template argument to the given
58 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
59 /// function that returns the "main" result or failure. Returns failure if the
60 /// pattern failed to apply. Extra arguments are forwarded to the pattern
61 /// constructor.
62 template <typename PatternTy, typename... Args>
63 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
64  // Check if the given operation has the type expected by the pattern.
65  using OpTy = typename llvm::function_traits<
66  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67  auto op = dyn_cast<OpTy>(operation);
68  if (!op)
69  return failure();
70 
71  // Apply the pattern directly to the op.
72  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
73  // We want to discourage direct use of PatternRewriter in APIs but In this
74  // very specific case, an IRRewriter is not enough.
75  PatternRewriter rewriter(operation->getContext());
76  rewriter.setInsertionPoint(operation);
77  auto result = pattern.returningMatchAndRewrite(op, rewriter);
78  if (failed(result))
79  return failure();
80  return cast<LinalgOp>(result->getOperation());
81 }
82 
83 /// Assuming that `ofr` is an index attr or a param of index type
84 /// or a transform dialect handle mapped to exactly one op
85 /// with one index result, return that value.
87  transform::TransformState &state, TransformOpInterface transformOp,
89  for (OpFoldResult ofr : ofrs) {
90  if (auto attr = dyn_cast<Attribute>(ofr)) {
91  if (!isa<IntegerAttr>(attr))
92  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
93  result.push_back(ofr);
94  continue;
95  }
96 
97  Value transformValue = cast<Value>(ofr);
98  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
99  ArrayRef<Attribute> params = state.getParams(transformValue);
100  if (params.size() != 1)
101  return transformOp.emitDefiniteFailure()
102  << "requires exactly one parameter associated";
103  result.push_back(params[0]);
104  continue;
105  }
106 
107  auto payloadOps = state.getPayloadOps(transformValue);
108  if (!llvm::hasSingleElement(payloadOps)) {
110  transformOp.emitSilenceableError()
111  << "handle must be mapped to exactly one payload op";
112  diag.attachNote(transformValue.getLoc())
113  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
114  return diag;
115  }
116 
117  Operation *op = *payloadOps.begin();
118  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
120  transformOp.emitSilenceableError()
121  << "payload op must have exactly 1 index result";
122  diag.attachNote(op->getLoc())
123  << "has " << op->getNumResults() << " results";
124  return diag;
125  }
126  result.push_back(op->getResult(0));
127  }
128 
130 }
131 
132 // Given a list of params that are index attrs or a list of OpFoldResults
133 // that are either index attrs or op handles, return a list of OpFoldResults
134 // of index attrs or a list of OpFoldResults where all op handles are
135 // replaced with the first (and only) OpResult of that payload op.
136 // (There must be exactly one parameter associated with the AnyParamType or
137 // one mapped payload op which must have exactly one index result.)
139  transform::TransformState &state, TransformOpInterface transformOp,
140  SmallVector<OpFoldResult> &result, Value packedHandle) {
141  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
142  ArrayRef<Attribute> params = state.getParams(packedHandle);
143  for (auto param : params) {
144  if (!isa<IntegerAttr>(param))
145  return transformOp.emitDefiniteFailure()
146  << "expected the parameter to be associated with an integer "
147  "attribute";
148  result.push_back(param);
149  }
151  }
152 
153  for (Operation *op : state.getPayloadOps(packedHandle)) {
154  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156  transformOp.emitSilenceableError()
157  << "payload op must have exactly 1 index result";
158  diag.attachNote(op->getLoc())
159  << "has " << op->getNumResults() << " results";
160  return diag;
161  }
162  result.push_back(op->getResult(0));
163  }
164 
166 }
167 
168 /// When possible, converts each `OpFoldResult` in `mixedResult` to
169 /// an integer if the value can be statically inferred. If a result
170 /// is a `Value` then it must be either a `ParamType` or a handle
171 /// to a constant-like op.
173  TransformState &state, TransformOpInterface &transformOp,
174  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
175  for (OpFoldResult paramOrHandle : mixedResults) {
176  if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177  reified.push_back(cast<IntegerAttr>(attr).getInt());
178  continue;
179  } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
180  ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
181  if (params.size() != 1)
182  return transformOp.emitSilenceableError() << "expected a single param";
183  reified.push_back(
184  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
185  continue;
186  }
187 
188  Value handle = cast<Value>(paramOrHandle);
189  if (!isa<TransformHandleTypeInterface>(handle.getType()))
190  return transformOp.emitSilenceableError() << "unexpected value handle";
191  auto payload = state.getPayloadOps(handle);
192  if (!llvm::hasSingleElement(payload))
193  return transformOp.emitSilenceableError()
194  << "requires param or handle that is mapped to 1 payload op";
195 
196  Operation *paramOrHandlePayloadOp = *payload.begin();
197  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
198  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
199  return transformOp.emitSilenceableError()
200  << "requires param or handle to be result of op with 1 index "
201  "result";
202  }
203 
204  IntegerAttr attr;
205  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
206  return transformOp.emitSilenceableError()
207  << "requires param or handle to be the result of a constant like "
208  "op";
209 
210  reified.push_back(attr.getInt());
211  }
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // Apply...PatternsOp
217 //===----------------------------------------------------------------------===//
218 
219 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
222 }
223 
224 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
227 }
228 
229 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
232 }
233 
234 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
238 }
239 
240 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
243  options.rankReductionStrategy =
246 }
247 
248 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
251 }
252 
253 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
256 }
257 
258 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
261 }
262 
263 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
266 }
267 
268 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // BufferizeToAllocationOp
275 //===----------------------------------------------------------------------===//
276 
277 namespace {
278 class NewOpsListener : public RewriterBase::ForwardingListener {
279 public:
281 
282  SmallVector<Operation *> getNewOps() const {
283  return SmallVector<Operation *>(newOps.begin(), newOps.end());
284  }
285 
286 private:
287  void notifyOperationInserted(Operation *op,
288  OpBuilder::InsertPoint previous) override {
289  ForwardingListener::notifyOperationInserted(op, previous);
290  // We only care about newly created ops.
291  if (previous.isSet())
292  return;
293  auto inserted = newOps.insert(op);
294  (void)inserted;
295  assert(inserted.second && "expected newly created op");
296  }
297 
298  void notifyOperationErased(Operation *op) override {
299  ForwardingListener::notifyOperationErased(op);
300  op->walk([&](Operation *op) { newOps.erase(op); });
301  }
302 
303  DenseSet<Operation *> newOps;
304 };
305 } // namespace
306 
307 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
310  // Attach listener to keep track of newly created ops.
311  OpBuilder::Listener *previousListener = rewriter.getListener();
312  auto resetListener =
313  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
314  NewOpsListener newOpsListener(previousListener);
315  rewriter.setListener(&newOpsListener);
316 
318  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
321  } else if (getMemcpyOp() == "memref.copy") {
322  options.memcpyOp =
324  } else if (getMemcpyOp() == "linalg.copy") {
325  options.memcpyOp =
327  } else {
328  llvm_unreachable("invalid memcpy op");
329  }
330  if (getAllocOp() == "memref.alloc") {
331  options.allocOp =
333  } else if (getAllocOp() == "memref.alloca") {
334  options.allocOp =
336  } else {
337  llvm_unreachable("invalid alloc op");
338  }
339  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
340  options.emitDealloc = getEmitDealloc();
341 
342  // Bufferize ops.
343  Attribute memorySpace =
344  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
345  SmallVector<Value> allocatedBuffers;
346  for (Operation *op : state.getPayloadOps(getTarget())) {
347  Value buffer =
348  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
349  if (!buffer) {
350  DiagnosedSilenceableFailure diag = emitSilenceableError()
351  << "failed to bufferize operation";
352  diag.attachNote(op->getLoc()) << "target payload op";
353  return diag;
354  }
355  allocatedBuffers.push_back(buffer);
356  }
357 
358  // Set results.
359  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
360  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
362 }
363 
364 void transform::BufferizeToAllocationOp::getEffects(
366  if (getBufferizeDestinationOnly()) {
367  // The destination is replaced with a newly allocated buffer, but the op
368  // itself remains in place.
369  onlyReadsHandle(getTargetMutable(), effects);
370  } else {
371  consumesHandle(getTargetMutable(), effects);
372  }
373  producesHandle(getOperation()->getOpResults(), effects);
374  modifiesPayload(effects);
375 }
376 
378  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
379  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
380  return emitOpError() << "unsupported memcpy op";
381  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
382  return emitOpError() << "unsupported alloc op";
383  return success();
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // PromoteTensorOp
388 //===----------------------------------------------------------------------===//
389 
390 /// Return true if the operand may be read from by its owner. This is currently
391 /// very conservative and only looks inside linalg operations to prevent
392 /// unintentional data loss.
393 static bool mayBeRead(OpOperand &operand) {
394  auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
395 
396  // Be conservative about ops we cannot analyze deeper.
397  if (!linalgOp)
398  return true;
399 
400  // Look inside linalg ops.
401  Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
402  return !blockArgument.use_empty();
403 }
404 
405 /// Return true if the value may be read through any of its uses.
406 static bool mayBeRead(Value value) {
407  // If the value has a reference semantics, it
408  // may be read through any alias...
409  if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
410  return true;
411  return llvm::any_of(value.getUses(),
412  static_cast<bool (&)(OpOperand &)>(mayBeRead));
413 }
414 
416 transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
418  transform::TransformState &state) {
419  SmallVector<Value> promoted;
420  for (Value tensor : state.getPayloadValues(getTensor())) {
421  auto type = dyn_cast<RankedTensorType>(tensor.getType());
422  if (!type) {
423  return emitSilenceableError() << "non-tensor type: " << tensor;
424  }
425 
426  Operation *definingOp = tensor.getDefiningOp();
427  if (definingOp)
428  rewriter.setInsertionPointAfter(definingOp);
429  else
430  rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
431 
432  // Check this before we emit operations using this value.
433  bool needsMaterialization = mayBeRead(tensor);
434 
435  SmallVector<Value> dynamicDims;
437  for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
438  if (!ShapedType::isDynamic(dim))
439  continue;
440  Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
441  auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
442  preservedOps.insert(dimOp);
443  dynamicDims.push_back(dimOp);
444  }
445  auto allocation = rewriter.create<bufferization::AllocTensorOp>(
446  tensor.getLoc(), type, dynamicDims);
447  // Set memory space if provided.
448  if (getMemorySpaceAttr())
449  allocation.setMemorySpaceAttr(getMemorySpaceAttr());
450  Value allocated = allocation;
451 
452  // Only insert a materialization (typically bufferizes to a copy) when the
453  // value may be read from.
454  if (needsMaterialization) {
455  auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
456  tensor.getLoc(), tensor, allocated);
457  preservedOps.insert(copy);
458  promoted.push_back(copy.getResult());
459  } else {
460  promoted.push_back(allocated);
461  }
462  rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
463  }
464  results.setValues(cast<OpResult>(getPromoted()), promoted);
466 }
467 
468 void transform::PromoteTensorOp::getEffects(
470  transform::onlyReadsHandle(getTensorMutable(), effects);
471  transform::producesHandle(getOperation()->getOpResults(), effects);
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // DecomposeOp
477 //===----------------------------------------------------------------------===//
478 
480 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
481  LinalgOp target,
483  transform::TransformState &state) {
484 #define DOWNSCALE(trans) \
485  { \
486  FailureOr<LinalgOp> res = tryApply<trans>(target); \
487  if (succeeded(res)) { \
488  results.push_back(*res); \
489  return DiagnosedSilenceableFailure::success(); \
490  } \
491  }
492 
493 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
494 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
495 
496  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
497  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
498  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
499  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
500  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
501  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
502  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
503  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
504  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
507 #undef DOWNSCALE_NORMAL
508 #undef DOWNSCALE_CALL
509 #undef DOWNSCALE
510  return emitDefaultSilenceableFailure(target);
511 }
512 
513 //===----------------------------------------------------------------------===//
514 // DecomposeInterfaceOp
515 //===----------------------------------------------------------------------===//
516 
517 // Decompose the target operation if it implements the AggregatedOpInterface.
518 // Push the decomposed operations (the ones that replace the values produced
519 // by \p target) into `results`.
520 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
521  transform::TransformRewriter &rewriter, Operation *target,
523  transform::TransformState &state) {
524  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
525  if (!decomposableOp) {
526  failed(rewriter.notifyMatchFailure(target,
527  "payload is not a decomposable op"));
528  return emitDefaultSilenceableFailure(target);
529  }
530 
531  FailureOr<SmallVector<Value>> maybeNewResults =
532  decomposableOp.decomposeOperation(rewriter);
533  if (failed(maybeNewResults))
534  return emitDefaultSilenceableFailure(target);
535 
536  rewriter.replaceOp(decomposableOp, *maybeNewResults);
537  for (Value val : *maybeNewResults) {
538  Operation *definition = val.getDefiningOp();
539  if (definition)
540  results.push_back(definition);
541  }
543 }
544 
545 //===----------------------------------------------------------------------===//
546 // EliminateLinalgOpAnchoredEmptyTensorsOp
547 //===----------------------------------------------------------------------===//
548 
549 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
551  onlyReadsHandle(getTargetMutable(), effects);
552  modifiesPayload(effects);
553 }
554 
556 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
557  transform::TransformRewriter &rewriter, TransformResults &transformResults,
558  TransformState &state) {
560  options.allowReturnAllocsFromLoops = true;
561 
562  for (Operation *target : state.getPayloadOps(getTarget())) {
564  if (failed(analyzeOp(target, state)))
565  return mlir::emitSilenceableFailure(target->getLoc())
566  << "failed to analyze op";
568  rewriter, target, state)))
569  return mlir::emitSilenceableFailure(target->getLoc())
570  << "failed to eliminate LinalgOp anchored tensor.empty ops";
571  }
573 }
574 
575 //===----------------------------------------------------------------------===//
576 // FuseOp
577 //===----------------------------------------------------------------------===//
578 
579 void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
580  TypeRange loopTypes, Value target,
581  ArrayRef<int64_t> staticTileSizes,
582  ArrayRef<int64_t> staticTileInterchange,
583  bool applyCleanup, bool useForall) {
584  return build(
585  builder, result, loopTypes,
586  /*target=*/target,
587  /*mixedTileSizes=*/
588  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
589  /*mixedTileInterchange=*/
590  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
591  applyCleanup, useForall);
592 }
593 
594 void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
595  Value target, ArrayRef<int64_t> staticTileSizes,
596  ArrayRef<int64_t> staticTileInterchange,
597  bool applyCleanup, bool useForall) {
598  return build(
599  builder, result,
600  /*target=*/target,
601  /*mixedTileSizes=*/
602  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
603  /*mixedTileInterchange=*/
604  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
605  applyCleanup, useForall);
606 }
607 
608 void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
609  Value target,
610  ArrayRef<OpFoldResult> mixedTileSizes,
611  ArrayRef<OpFoldResult> mixedTileInterchange,
612  bool applyCleanup, bool useForall) {
613  // Loop types are automaticaly splat by the callee, setting up one is
614  // enough.
615  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
616  build(builder, result, loopTypes, target, mixedTileSizes,
617  mixedTileInterchange, applyCleanup, useForall);
618 }
619 
620 void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
621  TypeRange loopTypes, Value target,
622  ArrayRef<OpFoldResult> mixedTileSizes,
623  ArrayRef<OpFoldResult> mixedTileInterchange,
624  bool applyCleanup, bool useForall) {
625  SmallVector<int64_t> staticTileSizes;
626  SmallVector<Value> dynamicTileSizes;
627  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
628  SmallVector<int64_t> staticTileInterchange;
629  SmallVector<Value> dynamicTileInterchange;
630  dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
631  staticTileInterchange);
632  // Call the default builder which sets up the proper operands segment sizes
633  // attributes for multiple variadic operands. In the absence of this,
634  // horrible bugs ensue.
635  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
636  auto staticTileInterchangeAttr =
637  builder.getDenseI64ArrayAttr(staticTileInterchange);
638  unsigned numExpectedLoops =
639  useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
640  SmallVector<Type> resultTypes;
641  resultTypes.reserve(numExpectedLoops);
642  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
643  "expected one loop type or as many as loops");
644  if (loopTypes.size() == 1)
645  resultTypes.append(numExpectedLoops, loopTypes[0]);
646  else
647  llvm::append_range(resultTypes, loopTypes);
648  build(builder, result, /*transformed=*/target.getType(),
649  /*loops=*/resultTypes,
650  /*target=*/target,
651  /*tile_sizes=*/dynamicTileSizes,
652  /*tile_interchange=*/dynamicTileInterchange,
653  /*static_tile_sizes=*/staticTileSizesAttr,
654  /*static_tile_interchange=*/staticTileInterchangeAttr,
655  /*apply_cleanup=*/applyCleanup,
656  /*use_forall=*/useForall);
657 }
658 
659 /// Apply a tiling transformation to all payload ops and store both the
660 /// tiled operation as well as the created tile loops.
661 template <typename Range>
662 static LogicalResult applyTilingToAll(
663  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
664  unsigned numLoops, transform::TransformResults &transformResults,
665  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
666  applyFn) {
667  SmallVector<Operation *> tiledLinalgOps;
668  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
669 
670  for (Operation *target : payloadOps) {
671  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
672  if (!tilingInterfaceOp)
673  return transformOp->emitError("only TilingInterface ops are supported");
674 
675  rewriter.setInsertionPoint(target);
676  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
677  applyFn(tilingInterfaceOp);
678  if (failed(tiledResults))
679  return failure();
680 
681  // Perform the replacement of tiled and fused values.
682  SmallVector<Operation *> opsToReplace{target};
683  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
684  for (Operation *toReplace : opsToReplace) {
685  for (OpResult res : toReplace->getResults())
686  if (auto replacement = tiledResults->replacements.lookup(res))
687  rewriter.replaceAllUsesWith(res, replacement);
688  if (toReplace->use_empty()) {
689  rewriter.eraseOp(toReplace);
690  }
691  }
692 
693  // Report back the relevant handles to the transform op.
694  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
695  assert(tiledResults->loops.size() == numLoops &&
696  "Mismatched number of loops, tile and fuse transform should have "
697  "failed");
698  for (unsigned int i = 0; i < numLoops; ++i)
699  loopOps[i].push_back(tiledResults->loops[i]);
700  }
701 
702  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
703  for (unsigned int i = 0; i < numLoops; ++i)
704  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
705 
706  return success();
707 }
708 
710 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
711  mlir::transform::TransformResults &transformResults,
713  auto transformOp = cast<TransformOpInterface>(getOperation());
714 
715  SmallVector<int64_t> tileSizes;
717  state, transformOp, getMixedTileSizes(), tileSizes);
718  if (!status.succeeded())
719  return status;
720  SmallVector<int64_t> tileInterchange;
722  state, transformOp, getMixedTileInterchange(), tileInterchange);
723  if (!status.succeeded())
724  return status;
725 
726  scf::SCFTilingOptions tilingOptions;
727  tilingOptions.interchangeVector = tileInterchange;
728  bool useForall = getUseForall();
729  tilingOptions.setLoopType(useForall
732  SmallVector<OpFoldResult> tileSizesOfr =
733  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
734  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
735  scf::SCFTileAndFuseOptions tileAndFuseOptions;
736  tileAndFuseOptions.tilingOptions = tilingOptions;
737 
738  if (getApplyCleanup()) {
739  MLIRContext *context = rewriter.getContext();
740  RewritePatternSet patterns(context);
741  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
744  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
745  }
746 
747  size_t numLoops =
748  useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
749  LogicalResult result = applyTilingToAll(
750  rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
751  transformResults,
752  [&](TilingInterface tilingInterfaceOp)
753  -> FailureOr<scf::SCFTileAndFuseResult> {
754  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
755  tileAndFuseOptions);
756  });
758  : DiagnosedSilenceableFailure::success();
759 }
760 
761 LogicalResult transform::FuseOp::verify() {
762  auto iterspace_rank = getStaticTileSizes().size();
763  ArrayRef<int64_t> permutation = getStaticTileInterchange();
764  if (permutation.size() > iterspace_rank)
765  return emitOpError()
766  << "interchange length exceeds iteration space dimensions ("
767  << iterspace_rank << "), found " << getTileInterchange();
768  SmallVector<bool> seen(iterspace_rank, false);
769  for (int64_t v : permutation) {
770  if (!ShapedType::isDynamic(v)) {
771  if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))
772  return emitOpError() << "expects interchange values to be in range [0, "
773  << iterspace_rank << "), found: " << v;
774  if (seen[v])
775  return emitOpError() << "found duplicate interchange value: " << v;
776  seen[v] = true;
777  }
778  }
779 
780  ArrayRef<int64_t> sizes = getStaticTileSizes();
781  size_t numExpectedLoops =
782  getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
783  if (numExpectedLoops != getNumResults() - 1)
784  return emitOpError() << "expects " << numExpectedLoops << " loop results";
785 
786  return success();
787 }
788 
789 SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
790  return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
791 }
792 
793 SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
794  return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
795  getContext());
796 }
797 
798 void transform::FuseOp::getEffects(
800  consumesHandle(getTargetMutable(), effects);
801  onlyReadsHandle(getTileSizesMutable(), effects);
802  onlyReadsHandle(getTileInterchangeMutable(), effects);
803  producesHandle(getOperation()->getOpResults(), effects);
804  modifiesPayload(effects);
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // FuseIntoContainingOp
809 //===----------------------------------------------------------------------===//
810 
811 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
812  OperationState &result,
813  Value producerOp,
814  Value containingOp) {
815  result.addOperands({producerOp, containingOp});
816  auto resultType = transform::AnyOpType::get(builder.getContext());
817  result.addTypes({resultType, resultType});
818 }
819 
820 /// Add new operands to the forall op for users of the producerOp
821 /// that are dominated by the containing scf.forall op.
823  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
824  Operation *containingOp, TilingResult &tileAndFuseResult,
825  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
826  SmallVector<OpFoldResult> &sizes) {
827 
828  // Count number of users not including the containing op
829  SetVector<Operation *> dominatedUsers;
830  DominanceInfo domInfo(containingOp);
831  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
832  if (!containingOp->isAncestor(user) &&
833  (domInfo.dominates(containingOp, user))) {
834  dominatedUsers.insert(user);
835  }
836  }
837  if (dominatedUsers.empty())
838  return nullptr;
839 
840  // Create new scf.forall op
841  auto forallOp = cast<scf::ForallOp>(containingOp);
842  OpBuilder::InsertionGuard g(rewriter);
843  rewriter.setInsertionPoint(forallOp);
844 
845  // Get new output
846  Location loc = forallOp.getLoc();
847  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
848  if (!genericOp)
849  return nullptr;
850  SmallVector<Value> outputs = genericOp.getOutputs();
851  SmallVector<Value> newOuts(forallOp.getOutputs());
852  newOuts.push_back(outputs[resultNumber]);
853 
854  // Create new scf.forall op
855  auto newforallOp = scf::ForallOp::create(
856  rewriter, loc, forallOp.getMixedLowerBound(),
857  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
858  forallOp.getMapping());
859  rewriter.eraseBlock(newforallOp.getBody());
860  newforallOp.getRegion().takeBody(forallOp.getRegion());
861 
862  // Add additional block argument for new value being returned
863  // and replaces all uses of the new output with corresponding bbArg
864  // inside the scf.forall to enable fusion into this new scf.forall.
865  newforallOp.getBody()->addArgument(newOuts.back().getType(),
866  newOuts.back().getLoc());
867  auto bbArgs = newforallOp.getBody()->getArguments();
868  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
869  [&](OpOperand &use) {
870  Operation *op = use.getOwner();
871  return newforallOp->isProperAncestor(op);
872  });
873 
874  // Fix terminator
875  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
876  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
877  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
878  Operation *firstYieldOp = yieldingOps.front();
879  rewriter.setInsertionPoint(firstYieldOp);
880  Value src = tileAndFuseResult.tiledValues[0];
881  Value dst = newforallOp.getRegionIterArgs().back();
882  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
883  tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
884  dst, offsets, sizes, strides);
885 
886  for (auto result : llvm::enumerate(forallOp.getResults())) {
887  rewriter.replaceAllUsesWith(result.value(),
888  newforallOp->getResult(result.index()));
889  }
890  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
891  newforallOp->getResults().back(),
892  [&](OpOperand &use) {
893  Operation *user = use.getOwner();
894  return dominatedUsers.contains(user);
895  });
896  return newforallOp;
897 }
898 
899 /// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
900 /// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
901 /// outer loop. To determine the second condition, this function iterates
902 /// using a worklist over the enclosing loops, trying to find 'src' in any of
903 /// the parent loop's iter args.
904 static bool sameOrEquivalentIterArg(Value src, Value dst) {
905  // Stack like vector containing possible iterArgs candidates. The first one
906  // is dst, and we will transverse the IR from there.
907  SmallVector<Value> destWorklist;
908  destWorklist.push_back(dst);
909 
910  while (!destWorklist.empty()) {
911  Value currentDst = destWorklist.pop_back_val();
912 
913  // We have found the same operand in some iter arg in the loop structure,
914  // so src and dst are equivalent.
915  if (src == currentDst)
916  return true;
917 
918  // The operands are not equivalent, look for enclosing loops over
919  // currentDst.
920  auto bbArg = dyn_cast<BlockArgument>(currentDst);
921  if (!bbArg)
922  continue;
923 
924  Block *parentBlock = bbArg.getOwner();
925  assert(parentBlock && "unlinked block argument");
926 
927  Operation *parentOp = parentBlock->getParentOp();
928  assert(parentOp && "expected block argument with parent operation");
929 
930  // Check if parent is loop-like. If it's not, do not add it to the worklist.
931  auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
932  if (!parentLoop)
933  continue;
934 
935  for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
936  // No need to check for null as innerIterArg is tied to parentLoop.
937  OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
938  Value loopBlockArgument =
939  parentLoop->getOperand(operand->getOperandNumber());
940  destWorklist.push_back(loopBlockArgument);
941  }
942  }
943 
944  return false;
945 }
946 
947 /// Find the first "extract" user of `producerOp` and tile it right before its
948 /// use. The tiled op is fused under the `containingOp`.
949 /// Return this fused op on success or nullptr if anything fails.
950 /// If tiled op has uses that are dominated by `containingOp`, return
951 /// a new `containingOp` with results of the fused op appended to
952 /// results of the `containingOp` or nullptr if there are no dominated uses.
953 static std::tuple<SmallVector<Operation *>, Operation *>
955  Operation *producerOp, Operation *containingOp) {
956  LDBG() << "Try to fuse a direct extract use";
  // The producer must implement TilingInterface so that a tile of it can be
  // materialized at the slice.
957  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
958  if (!tileableProducer) {
959  diag.attachNote(producerOp->getLoc())
960  << "producer is not a TileableInterface: " << *producerOp;
961  return {};
962  }
963 
964  // Search the producer slices accessed within the containing operation.
965  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
966  // evolve into an interface.
967  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
968  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
969  return sliceOp && containingOp->isProperAncestor(sliceOp);
970  });
971 
972  // Find a fusion opportunity.
973  if (it == tileableProducer->getUsers().end()) {
974  diag.attachNote(tileableProducer->getLoc())
975  << "could not find fusion opportunity for: " << *tileableProducer;
976  return {};
977  }
978  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
979 
980  // Try to fuse the producer in-place.
981  OpBuilder::InsertionGuard guard(rewriter);
982  rewriter.setInsertionPoint(sliceOpToTile);
983 
984  // Clone the producer inside the consumer and try to update the producer init
985  // operands using the loop bbArgs if applicable. More precisely, if the bbArg
986  // of the container loop points to a value that it is used by the consumer op,
987  // then, instead of using such value on the consumer, use the value coming
988  // from the bbArg instead. This allows to reuse the output tensor (instead of
989  // creating a new one) of the container when both producer and container write
990  // to the same output.
991  if (LoopLikeOpInterface containerLoop =
992  dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
993  Operation *clone = rewriter.clone(*producerOp);
994  rewriter.modifyOpInPlace(clone, [&]() {
995  // Iterate over the outputs of the producer and over the loop bbArgs and
996  // check if any bbArg points to the same value as the producer output. In
997  // such case, make the producer output point to the bbArg directly.
998  for (OpOperand &initOperandPtr :
999  cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
1000  Value producerOperand =
1001  clone->getOperand(initOperandPtr.getOperandNumber());
1002  for (BlockArgument containerIterArg :
1003  containerLoop.getRegionIterArgs()) {
1004  OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1005  Value consumerOperand =
1006  containerLoop->getOperand(bbArg->getOperandNumber());
1007  // The producer has the same init as the loop bbArg, use it.
1008  if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
1009  initOperandPtr.set(containerIterArg);
1010  }
1011  }
1012  }
1013  });
1014 
  // From here on, tile the clone (with its inits redirected to the loop
  // bbArgs) rather than the original producer.
1015  tileableProducer = dyn_cast<TilingInterface>(clone);
1016  }
1017 
1018  // Tile the producer.
1019  int64_t resultNumber =
1020  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1021  LDBG() << "resultNumber: " << resultNumber;
1022 
1023  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
1024  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
1025 
1026  FailureOr<TilingResult> tileAndFuseResult =
1027  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1028  sizes);
1029 
1030  if (failed(tileAndFuseResult)) {
1031  diag.attachNote(tileableProducer->getLoc())
1032  << "failed to tile producer op: " << *tileableProducer;
1033  return {};
1034  }
1035 
1036 #ifndef NDEBUG
1037  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
1038  LDBG() << "tiledProducer: " << *tiledOp;
1039  }
1040 #endif
1041 
1042  // Replace the extract op.
1043  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1044  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1045  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1046  if (failed(maybeRankReduced)) {
1047  diag.attachNote(producerOp->getLoc())
1048  << "shape types don't match (missing canonicalization?):\nTiledOp: "
1049  << tileAndFuseResult->tiledValues[0]
1050  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
1051  return {};
1052  }
1053  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1054 
1055  // Add new outputs to containing op, if required
1056  Operation *newContainingOp = replaceForAllWithNewSignature(
1057  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
1058  resultNumber, offsets, sizes);
1059 
1060  // Cleanup clone.
  // If a clone was created above, `tileableProducer` now points at it; its
  // tile has been materialized, so the untiled clone is dead.
1061  if (dyn_cast<LoopLikeOpInterface>(containingOp))
1062  rewriter.eraseOp(tileableProducer);
1063 
1064  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1065 }
1066 
1067 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
1068 /// it is exactly the `containingOp`, otherwise bail.
1069 /// Then, find the first "extract" user of the tied block argument and tile it
1070 /// right before its "extract" use. The tiled op is fused under the
1071 /// `containingOp`.
1072 /// Return this fused op on success or nullptr if anything fails.
1075  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
1076  Operation *containingOp) {
1077  LDBG() << "Try to fuse an extract use through block argument";
1078 
  // The producer must implement TilingInterface so that a tile of it can be
  // materialized at the slice.
1079  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1080  if (!tileableProducer) {
1081  diag.attachNote(producerOp->getLoc())
1082  << "producer is not a TileableInterface: " << *producerOp;
1083  return {};
1084  }
1085 
1086  // Search the first use by a "scf::ForallOp" user.
1087  scf::ForallOp forallOp;
1088  auto itProducerUses =
1089  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
1090  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1091  return forallOp;
1092  });
1093  // If it's not from the containing op, return.
1094  if (!forallOp || forallOp != containingOp) {
1095  diag.attachNote(tileableProducer->getLoc())
1096  << "could not find a use by the containing op: " << *tileableProducer;
1097  return {};
1098  }
1099 
1100  // Search the producer slices accessed within the containing
1101  // operation.
1102  // TODO: Generalize to more extract/insert/parallel_insert triples.
1103  // Maybe evolve into an interface.
1104  OpOperand *pUse = &(*itProducerUses);
1105  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
1106 
1107  // Search the producer slices accessed within the containing operation.
1108  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
1109  // evolve into an interface.
1110  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
1111  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1112  return sliceOp && containingOp->isProperAncestor(sliceOp);
1113  });
1114 
1115  // Find a fusion opportunity.
1116  if (itBBArgUsers == bbArg.getUsers().end()) {
1117  diag.attachNote(containingOp->getLoc())
1118  << "could not find fusion opportunity for bbArg: " << bbArg;
1119  return {};
1120  }
1121  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1122 
1123  // Try to fuse the producer in-place.
1124  OpBuilder::InsertionGuard guard(rewriter);
1125  rewriter.setInsertionPoint(sliceOpToTile);
1126 
1127  // Replace the use in the tileableProducer before tiling: clone, replace and
1128  // then tile.
1129  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
1130  LDBG() << "resultNumber: " << resultNumber;
1131 
1132  // Gather destination tensors.
1133  SmallVector<Value> destinationTensors;
1135  rewriter, tileableProducer->getLoc(), tileableProducer,
1136  destinationTensors))) {
1137  diag.attachNote(tileableProducer->getLoc())
1138  << "failed to get destination tensors for: " << *tileableProducer;
1139  return {};
1140  }
1141 
  // Map the producer's destination to the loop bbArg so the clone reads the
  // shared iter_arg tensor instead of the original destination.
1142  IRMapping bvm;
1143  bvm.map(destinationTensors[resultNumber], bbArg);
1144  auto tileableProducerClone =
1145  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
  // The clone only exists to be tiled; always erase it on exit.
1146  auto scopeGuard =
1147  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
1148 
1149  // Tile the producer.
1150  FailureOr<TilingResult> tileAndFuseResult =
1151  tileableProducerClone.generateResultTileValue(
1152  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1153  sliceOpToTile.getMixedSizes());
1154  if (failed(tileAndFuseResult)) {
1155  diag.attachNote(tileableProducer->getLoc())
1156  << "failed to tile producer op: " << *tileableProducer;
1157  return {};
1158  }
1159 
1160  // Replace the extract op.
1161  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1162  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1163  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1164  assert(succeeded(maybeRankReduced) && "unexpected shape");
1165  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1166 
1167  // Replace the use in containingOp.
1168  rewriter.modifyOpInPlace(containingOp, [&]() {
1169  containingOp->setOperand(pUse->getOperandNumber(),
1170  destinationTensors.front());
1171  });
1172 
1173  return tileAndFuseResult->tiledOps;
1174 }
1175 
/// Clone `producerOp` right before its first use inside `containingOp` and
/// rewire that use to the clone's matching result. Returns the clone, or
/// nullptr when no use inside `containingOp` exists or when a use is by the
/// containing op itself (attaching a note to `diag` in both failure cases).
1177  Operation *producerOp,
1178  Operation *containingOp) {
1179  LDBG() << "Try to fuse an use by cloning";
1180 
1181  // Gather all uses inside the containing op.
1183  for (OpResult result : producerOp->getOpResults()) {
1184  for (OpOperand &use : result.getUses()) {
1185  if (containingOp->isProperAncestor(use.getOwner())) {
1186  uses.push_back(&use);
1187  continue;
1188  }
1189  // Cannot clone and fuse if the use is by the containing op itself: fail
1190  // immediately.
1191  if (containingOp == use.getOwner()) {
1192  diag.attachNote(producerOp->getLoc())
1193  << "producer op use by containing op cannot be fused by cloning";
1194  return nullptr;
1195  }
1196  }
1197  }
1198 
1199  // Check for a non-empty list of fusion opportunities.
1200  if (uses.empty()) {
1201  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1202  return nullptr;
1203  }
1204 
1205  // Clone and fuse inside the containing op.
  // Only the first collected use is rewired; other uses keep referring to the
  // original producer.
1206  Operation *fusedOp = nullptr;
1207  OpOperand *use = uses.front();
1208  // Parallel insert slice is not a valid clone destination.
1209  // TODO: Generalize to other type of ops.
1210  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1211  "Parallel insert slice is not a valid clone destination");
1212  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1213  LDBG() << "resultNumber: " << resultNumber;
1214 
1215  OpBuilder::InsertionGuard guard(rewriter);
1216  rewriter.setInsertionPoint(use->getOwner());
1217  fusedOp = rewriter.clone(*producerOp);
1218  rewriter.modifyOpInPlace(
1219  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1220 
1221  return fusedOp;
1222 }
1223 
1224 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1225  // Allow repeated handles since we are fusing everything anyway.
1226  return true;
1227 }
1228 
1230 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1231  transform::TransformResults &results,
1232  transform::TransformState &state) {
1233  SmallVector<Operation *> fusedOps;
1234  auto producerOps = state.getPayloadOps(getProducerOp());
1235  auto containingOps = state.getPayloadOps(getContainingOp());
1236  if (!llvm::hasSingleElement(containingOps)) {
1237  return emitDefiniteFailure()
1238  << "requires exactly one containing_op handle (got "
1239  << llvm::range_size(containingOps) << ")";
1240  }
1241  Operation *containingOp = *containingOps.begin();
1242 
1243  // If nothing to fuse, propagate success.
1244  if (std::empty(producerOps)) {
1245  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1246  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1248  }
1249 
1250  // Helper function to find the next producer that should be fused. Take any
1251  // producer that has a use inside the containing op.
1252  SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1253  auto getNextProducer = [&]() -> FailureOr<Operation *> {
1254  for (const auto &it : enumerate(remainingProducers)) {
1255  Operation *producerOp = it.value();
1256  // The containing op may be a user of producerOp: use isAncestor.
1257  int64_t numUsesInContainingOp =
1258  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1259  return containingOp->isAncestor(op);
1260  });
1261  // TODO: When resolving the TODO below (no duplicate ops), take an op
1262  // that has no use among the remaining producers. This is a topological
1263  // sorting.
1264  if (numUsesInContainingOp > 0) {
  // A producer is only removed from the worklist once its last use in
  // the containing op is about to be fused.
1265  if (numUsesInContainingOp == 1)
1266  remainingProducers.erase(remainingProducers.begin() + it.index());
1267  return producerOp;
1268  }
1269  }
1270  return failure();
1271  };
1272 
  // Repeatedly pick a fusable producer and try the three fusion strategies in
  // order: direct extract use, extract through a loop block argument, and
  // cloning.
1273  while (!remainingProducers.empty()) {
1274  auto nextProducer = getNextProducer();
1275  if (failed(nextProducer)) {
1276  auto diag = mlir::emitSilenceableFailure(getLoc())
1277  << "could not find next producer to fuse into container";
1278  diag.attachNote(containingOp->getLoc()) << "containing op";
1279  return diag;
1280  }
1281 
1282  Operation *producerOp = *nextProducer;
1283 
1284  // Default diagnostic, to be complemented with more failure information.
1286  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1287 
1288  // TODO: If there are multiple uses of the producer in the containing op,
1289  // we currently tile/clone the op multiple times (once per use). In some
1290  // cases, we can tile/clone once and reuse the value for each use.
1291  // Futhermore, producers should then be traversed according to a
1292  // topological sorting.
1293  auto [tiledOps, newContainingOp] =
1294  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1295  if (!tiledOps.empty()) {
1296  LDBG() << "\nFused a direct extract use\n" << *containingOp;
1297  fusedOps.append(tiledOps);
1298  if (newContainingOp) {
1299  // Update handles associated with the containing op so we don't need to
1300  // invalidate them. This is a hack to support better composability
1301  // between tiling and fusion while a proper mechanism is being
1302  // investigated.
1303  //
1304  // DO NOT replicate this elsewhere unless you understand what you are
1305  // doing.
1306  LogicalResult replacementStatus =
1307  rewriter.notifyPayloadOperationReplaced(containingOp,
1308  newContainingOp);
1309  (void)replacementStatus;
1310  assert(succeeded(replacementStatus) &&
1311  "unable to update transform state mapping");
1312  rewriter.eraseOp(containingOp);
1313  containingOp = newContainingOp;
1314  }
1315  continue;
1316  }
1317 
1318  SmallVector<Operation *> tiledContainingOpOperand =
1320  rewriter, diag, producerOp, containingOp);
1321  if (!tiledContainingOpOperand.empty()) {
1322  LDBG() << "\nFused an extract use through block argument\n"
1323  << *containingOp;
1324  fusedOps.append(tiledContainingOpOperand);
1325  continue;
1326  }
1327 
1328  Operation *cloned =
1329  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1330  if (cloned) {
1331  LDBG() << "\nFused an use by cloning\n" << *containingOp;
1332  fusedOps.push_back(cloned);
1333  continue;
1334  }
1336  }
1337 
1338  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1339  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1341 }
1342 
1343 void transform::FuseIntoContainingOp::getEffects(
  // The producer handle is consumed (its payload ops are fused away) while
  // the containing-op handle is only read; both results are produced anew.
1345  consumesHandle(getProducerOpMutable(), effects);
1346  onlyReadsHandle(getContainingOpMutable(), effects);
1347  producesHandle(getOperation()->getOpResults(), effects);
1348  modifiesPayload(effects);
1349 }
1350 
1351 //===----------------------------------------------------------------------===//
1352 // GeneralizeOp
1353 //===----------------------------------------------------------------------===//
1354 
1356 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1357  LinalgOp target,
1359  transform::TransformState &state) {
1360  // Exit early if no transformation is needed.
1361  if (isa<GenericOp>(target)) {
1362  results.push_back(target);
1364  }
  // Attempt to generalize the named op; on failure emit the default
  // silenceable diagnostic for this target.
1365  rewriter.setInsertionPoint(target);
1366  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1367  if (succeeded(generic)) {
1368  results.push_back(generic->getOperation());
1370  }
1371  return emitDefaultSilenceableFailure(target);
1372 }
1373 
1374 //===----------------------------------------------------------------------===//
1375 // SpecializeOp
1376 //===----------------------------------------------------------------------===//
1377 
1379 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1380  LinalgOp target,
1382  transform::TransformState &state) {
1383  // Exit early if the operation is not a generic.
1384  if (!isa<GenericOp>(target)) {
1385  results.push_back(target);
1387  }
  // Attempt to raise the generic back to a named op; on failure emit the
  // default silenceable diagnostic for this target.
1388  rewriter.setInsertionPoint(target);
1389  FailureOr<LinalgOp> named =
1390  specializeGenericOp(rewriter, cast<GenericOp>(target));
1391  if (succeeded(named)) {
1392  results.push_back(named->getOperation());
1394  }
1395  return emitDefaultSilenceableFailure(target);
1396 }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // InterchangeOp
1400 //===----------------------------------------------------------------------===//
1401 
1403 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1404  GenericOp target,
1406  transform::TransformState &state) {
1407  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1408  // Exit early if no transformation is needed.
1409  if (interchangeVector.empty()) {
1410  results.push_back(target);
1412  }
1413 
  // The interchange vector must cover every loop of the target exactly once.
1414  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1415  if (interchangeVector.size() != numLoops) {
1416  return emitSilenceableError()
1417  << getIteratorInterchangeAttrName() << " has length ("
1418  << interchangeVector.size()
1419  << ") different from the number of loops in the target operation ("
1420  << numLoops << ")";
1421  }
1422  FailureOr<GenericOp> res = interchangeGenericOp(
1423  rewriter, target, SmallVector<unsigned>(interchangeVector));
1424  if (failed(res))
1425  return emitDefiniteFailure() << "failed to apply";
1426  results.push_back(res->getOperation());
1428 }
1429 
1430 LogicalResult transform::InterchangeOp::verify() {
1431  ArrayRef<int64_t> permutation = getIteratorInterchange();
1432  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1433  if (!std::is_permutation(sequence.begin(), sequence.end(),
1434  permutation.begin(), permutation.end())) {
1435  return emitOpError()
1436  << "expects iterator_interchange to be a permutation, found "
1437  << getIteratorInterchange();
1438  }
1439  return success();
1440 }
1441 
1442 //===----------------------------------------------------------------------===//
1443 // LinalgCopyToMemrefOp
1444 //===----------------------------------------------------------------------===//
1445 
1446 DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1447  transform::TransformRewriter &rewriter, Operation *targetOp,
1449  transform::TransformState &state) {
1450 
1451  // Check if the target can be converted.
1452  if (!isa<linalg::CopyOp>(targetOp)) {
1454  emitSilenceableError() << "only linalg.copy target ops are supported";
1455  diag.attachNote(targetOp->getLoc()) << "target op";
1456  return diag;
1457  }
1458 
  // memref.copy has no tensor form; reject tensor-semantics copies.
1459  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1460  if (!copyOp.hasPureBufferSemantics()) {
1462  emitSilenceableError()
1463  << "cannot transform a linalg.copy on tensors into a memref.copy";
1464  diag.attachNote(targetOp->getLoc()) << "target op";
1465  return diag;
1466  }
1467 
1468  SmallVector<Value> inputs = copyOp.getInputs();
1469  SmallVector<Value> outputs = copyOp.getOutputs();
1470  assert(inputs.size() == 1 && "expected linalg copy op with one input");
1471  assert(outputs.size() == 1 && "expected memref copy op with one output");
1472  Value input = inputs.front();
1473  Value output = outputs.front();
1474 
1475  // linalg.copy supports different element types on source/dest whereas
1476  // memref.copy does not, so we must check that the source and dest types can
1477  // be handled by memref.copy and otherwise reject the transformation.
1478  if (!isa<ShapedType>(input.getType())) {
1480  emitSilenceableError()
1481  << "cannot transform a linalg.copy which input has no shape";
1482  diag.attachNote(targetOp->getLoc()) << "target op";
1483  return diag;
1484  }
1485 
1486  // linalg.copy destination must be a shaped type.
1487  assert(isa<ShapedType>(output.getType()));
1488 
1489  if (cast<ShapedType>(input.getType()).getElementType() !=
1490  cast<ShapedType>(output.getType()).getElementType()) {
1492  emitSilenceableError()
1493  << "cannot transform a linalg.copy with different source and "
1494  "destination element types ";
1495  diag.attachNote(targetOp->getLoc()) << "target op";
1496  return diag;
1497  }
1498 
1499  // Target can be converted, do it.
1500  auto memrefCopyOp =
1501  rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1502 
1503  results.push_back(memrefCopyOp);
1505 }
1506 
1507 //===----------------------------------------------------------------------===//
1508 // LowerPackOp
1509 //===----------------------------------------------------------------------===//
1510 
1511 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1512  transform::TransformRewriter &rewriter, linalg::PackOp target,
1513  transform::ApplyToEachResultList &transformResults,
1514  transform::TransformState &state) {
1515  rewriter.setInsertionPoint(target);
1516  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1517  FailureOr<LowerPackResult> res =
1518  lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1519  if (failed(res)) {
1520  return mlir::emitSilenceableFailure(target->getLoc())
1521  << "cannot lower to pad + expand + transpose";
1522  }
  // Results are returned in lowering order: pad, expand_shape, transpose.
1523  transformResults.push_back(res->padOp);
1524  transformResults.push_back(res->expandShapeOp);
1525  transformResults.push_back(res->transposeOp);
1527 }
1528 
1529 //===----------------------------------------------------------------------===//
1530 // LowerUnPackOp
1531 //===----------------------------------------------------------------------===//
1532 
1533 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1534  transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1535  transform::ApplyToEachResultList &transformResults,
1536  transform::TransformState &state) {
1537  rewriter.setInsertionPoint(target);
1538  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1539  FailureOr<LowerUnPackOpResult> res =
1540  lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1541  if (failed(res)) {
1543  emitSilenceableError()
1544  << "cannot lower to transpose + collapse + extract";
1545  diag.attachNote(target->getLoc()) << "target payload op";
1546  return diag;
1547  }
  // Results are returned in lowering order: empty, transpose, collapse_shape,
  // extract_slice.
1548  transformResults.push_back(res->emptyOp);
1549  transformResults.push_back(res->transposeOp);
1550  transformResults.push_back(res->collapseShapeOp);
1551  transformResults.push_back(res->extractSliceOp);
1553 }
1554 
1555 //===---------------------------------------------------------------------===//
1556 // MatchOp
1557 //===---------------------------------------------------------------------===//
1558 
1559 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1560  Value target, ArrayRef<StringRef> opNames) {
1561  result.addOperands(target);
1562  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1563  builder.getStrArrayAttr(opNames));
1564  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1565 }
1566 
1567 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1568  TypeRange resultTypes, Value target,
1569  ArrayRef<StringRef> opNames) {
1570  result.addOperands(target);
1571  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1572  builder.getStrArrayAttr(opNames));
1573  result.addTypes(resultTypes);
1574 }
1575 
1577 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1578  transform::TransformResults &results,
1579  transform::TransformState &state) {
  // Optional set of op names to match against, taken from the `ops` attribute.
1580  llvm::StringSet<> strs;
1581  if (getOps().has_value())
1582  strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1583 
1584  auto payloadOps = state.getPayloadOps(getTarget());
1585  if (!llvm::hasSingleElement(payloadOps)) {
1586  return emitDefiniteFailure("requires exactly one target handle");
1587  }
1588 
  // Predicate applied to every op reachable from the payload root; an op must
  // pass all requested filters (name, interface, attributes, result type,
  // operand types) to be collected.
1590  bool incorrectNumOperandTypes = false;
1591  auto matchFun = [&](Operation *op) {
1592  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1593  return;
1594 
1595  // Interfaces cannot be matched by name, just by ID.
1596  // So we specifically encode the interfaces we care about for this op.
1597  if (getInterface().has_value()) {
1598  auto iface = getInterface().value();
1599  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1600  !isa<LinalgOp>(op))
1601  return;
1602  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1603  !isa<TilingInterface>(op))
1604  return;
1605  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1606  !isa<LoopLikeOpInterface>(op))
1607  return;
1608  }
1609 
1610  // Check if all specified attributes match.
1611  if (getOpAttrs().has_value()) {
1612  DictionaryAttr opAttrs = getOpAttrs().value();
1613  for (NamedAttribute attr : opAttrs) {
1614  if (attr.getName() == getInterfaceAttrName() ||
1615  attr.getName() == getOpsAttrName())
1616  continue;
1617  if (!op->hasAttr(attr.getName()))
1618  return;
1619  if (op->getAttr(attr.getName()) != attr.getValue())
1620  return;
1621  }
1622  }
1623 
1624  if (getFilterResultType().has_value()) {
1625  Type t = getFilterResultType().value();
1626  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1627  return;
1628  }
1629 
1630  if (getFilterOperandTypes().has_value()) {
1631  mlir::ArrayAttr types = getFilterOperandTypes().value();
1632  auto operandTypes = op->getOperandTypes();
1633 
1634  if (types.size() == 1) {
1635  // All the operands must must be equal to the specified type
1636  auto typeattr =
1637  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0])
1638  Type t = cast<::mlir::Type>(typeattr.getValue());
1639  if (!llvm::all_of(op->getOperandTypes(),
1640  [&](Type operandType) { return operandType == t; }))
1641  return;
1642  } else {
1643  // The operand types must match all the types in the list (in the same
1644  // order in with they are specified)
1645  if (types.size() != operandTypes.size()) {
  // Arity mismatch is a user error; flag it so a definite failure is
  // emitted after the walk.
1646  incorrectNumOperandTypes = true;
1647  return;
1648  }
1649 
1650  for (auto [attr, operandType] :
1651  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1652  auto typeattr = cast<mlir::TypeAttr>(attr);
1653  Type type = cast<::mlir::Type>(typeattr.getValue());
1654 
1655  if (type != operandType)
1656  return;
1657  }
1658  }
1659  }
1660 
1661  // All constraints are satisfied.
1662  res.push_back(op);
1663  return;
1664  };
1665 
1666  (*payloadOps.begin())->walk(matchFun);
1667  if (incorrectNumOperandTypes)
1668  return emitDefiniteFailure("If filter_operand_types contains more than a "
1669  "type, then it must contain as much types as "
1670  "the number of operands in the target ops");
1671  results.set(cast<OpResult>(getResult()), res);
1673 }
1674 
1675 //===---------------------------------------------------------------------===//
1676 // MultiTileSizesOp
1677 //===---------------------------------------------------------------------===//
1678 
/// Prints the type section of MultiTileSizesOp as a single functional type;
/// only the low-size type is printed because all three result types are
/// identical (see the matching parser below, which reconstructs them).
1680  Type targetType, Type lowSizeType, Type,
1681  Type) {
1682  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1683 }
1684 
1685 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1686  Type &targetType, Type &lowSizeType,
1687  Type &highSizeType,
1688  Type &splitPointType) {
1689  FunctionType funcType;
1690  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1691  if (failed(parser.parseType<FunctionType>(funcType)))
1692  return failure();
1693 
1694  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1695  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1696  "argument and one result";
1697  }
1698  targetType = funcType.getInput(0);
1699  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1700 
1701  return success();
1702 }
1703 
1704 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1705  transform::TransformRewriter &rewriter, LinalgOp target,
  // Param-typed results: compute the sizes statically and return attributes.
1707  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1708  if (target.hasDynamicShape()) {
1709  auto diag = emitSilenceableError()
1710  << "cannot compute parametric tile sizes for dynamically "
1711  "shaped payload op";
1712  diag.attachNote(target->getLoc()) << "payload op";
1713  return diag;
1714  }
1715 
1716  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1717  target, getDimension(), getTargetSize(), getDivisor());
1718  if (failed(spec)) {
1719  return emitSilenceableError()
1720  << "failed to compute multi-size tiling sizes";
1721  }
1722 
  // Return low size, high size, and the split point (lowSize * lowTripCount)
  // as integer attributes of the declared param element type.
1723  Builder builder(target.getContext());
1724  results.assign(llvm::map_range(
1725  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1726  spec->lowTileSize * spec->lowTripCount}),
1727  [&builder, this](int64_t value) {
1728  return builder.getIntegerAttr(
1729  cast<ParamType>(getLowSize().getType()).getType(), value);
1730  }));
1732  }
1733 
  // Value-typed results: materialize the size computation in the IR right
  // before the target op.
1734  OpBuilder builder(target.getContext());
1735  builder.setInsertionPoint(target);
1736  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1737  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1738  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1739  builder, target, getDimension(), targetSize, divisor);
1740  if (failed(spec)) {
1741  return emitSilenceableError() << "could not generate tile size computation";
1742  }
1743 
  // splitPoint = lowTileSize * lowTripCount, built as an affine.apply.
1744  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1745  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1746  Operation *splitPoint =
1747  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1748  {spec->lowTileSize, spec->lowTripCount});
1749  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1750  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1751  assert(lowTileSize && highTileSize && splitPoint &&
1752  "tile sizes are not produced by operations");
1753  results.reserve(results.size() + 3);
1754  results.push_back(lowTileSize);
1755  results.push_back(highTileSize);
1756  results.push_back(splitPoint);
1758 }
1759 
1760 void transform::MultiTileSizesOp::getEffects(
1762  onlyReadsHandle(getTargetMutable(), effects);
1763  producesHandle(getOperation()->getOpResults(), effects);
1764  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1765  onlyReadsPayload(effects);
1766  else
1767  modifiesPayload(effects);
1768 }
1769 
1770 LogicalResult transform::MultiTileSizesOp::verify() {
1771  if (getLowSize().getType() != getHighSize().getType() ||
1772  getLowSize().getType() != getSplitPoint().getType()) {
1773  return emitOpError() << "expects all results type to be the same";
1774  }
1775  return success();
1776 }
1777 
1778 //===---------------------------------------------------------------------===//
1779 // PackOp
1780 //===---------------------------------------------------------------------===//
1781 
1782 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1783  Value target,
1784  ArrayRef<OpFoldResult> mixedPackedSizes) {
1785  SmallVector<int64_t> staticPackedSizes;
1786  SmallVector<Value> dynamicPackedSizes;
1787  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1788  staticPackedSizes);
1789  // Call the default builder which sets up the proper operands segment sizes
1790  // attributes for multiple variadic operands. In the absence of this, horrible
1791  // bugs ensue.
1792  Type linalgOpHType = transform::OperationType::get(
1793  builder.getContext(), GenericOp::getOperationName());
1794  build(builder, result,
1795  /*resultType=*/linalgOpHType,
1796  /*target=*/target,
1797  /*dynamic_sizes=*/dynamicPackedSizes,
1798  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1799 }
1800 
1801 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1802  Builder b(getContext());
1803  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1804 }
1805 
1807 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1808  transform::TransformResults &transformResults,
1809  transform::TransformState &state) {
1810  auto targetOps = state.getPayloadOps(getTarget());
1811  // If nothing to pack, propagate success.
1812  if (std::empty(targetOps)) {
1813  transformResults.set(cast<OpResult>(getPackedOp()),
1814  ArrayRef<Operation *>({}));
1816  }
1817  // Fail on multi-op handles.
1818  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1819  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1820  return emitSilenceableError()
1821  << "requires target to map to exactly 1 LinalgOp (got "
1822  << llvm::range_size(targetOps) << ")";
1823  }
1824  // Fail on mismatched number of pack sizes.
1825  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1826  return emitSilenceableError()
1827  << "requires number of packed sizes match the number of loops ("
1828  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1829  << ")";
1830  }
1831 
1832  // Unpack handles to constants or actual SSA index values.
1833  SmallVector<OpFoldResult> packedSizes;
1835  state, *this, packedSizes, getMixedPackedSizes());
1836 
1837  rewriter.setInsertionPoint(linalgOp);
1838  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1839  if (failed(maybeResult))
1840  return emitDefiniteFailure("data tiling failed");
1841 
1842  transformResults.set(cast<OpResult>(getPackedOp()),
1843  {maybeResult->packedLinalgOp.getOperation()});
1845 }
1846 
1847 void transform::PackOp::getEffects(
1849  transform::consumesHandle(getTargetMutable(), effects);
1850  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1851  transform::producesHandle(getOperation()->getOpResults(), effects);
1852  transform::modifiesPayload(effects);
1853 }
1854 
1855 //===---------------------------------------------------------------------===//
1856 // PackGreedilyOp.
1857 //===---------------------------------------------------------------------===//
1858 
1859 LogicalResult transform::PackGreedilyOp::verify() {
1860  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1861  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1862  << " is not a valid permutation";
1863  }
1864  // TODO: relax to allow empty once we have another strategy than just matmul.
1865  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1866  for (auto [s, nmo] :
1867  llvm::zip_equal(getMixedMatmulPackedSizes(),
1868  getMatmulPaddedSizesNextMultipleOf())) {
1869  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1870  if (nmo != 0 &&
1871  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1872  return emitOpError() << "at most one of the packed_size and the "
1873  "padded_sizes_next_multiple_of can be nonzero "
1874  "for the matmul strategy";
1875  }
1876  }
1877  }
1878  return success();
1879 }
1880 
1882 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1883  transform::TransformResults &transformResults,
1884  transform::TransformState &state) {
1885  SmallVector<Operation *> results;
1886  for (Operation *op : state.getPayloadOps(getTarget())) {
1887  auto linalgOp = dyn_cast<LinalgOp>(op);
1888  if (!linalgOp)
1889  continue;
1890  // linalgOp will be replaced and the insertion point may be invalidated if
1891  // we set it before -> set it after.
1892  rewriter.setInsertionPointAfter(linalgOp);
1893  // Failing to pack greedily is perfectly fine.
1894  // In the future we will want to order packings according to some metric.
1895  FailureOr<PackResult> packResult = packMatmulGreedily(
1896  /*rewriter=*/rewriter,
1897  /*linalgOp=*/linalgOp,
1898  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1899  /*mnkPaddedSizesNextMultipleOf=*/
1900  getMatmulPaddedSizesNextMultipleOf(),
1901  /*mnkOrder=*/getMatmulInnerDimsOrder());
1902  if (succeeded(packResult)) {
1903  results.push_back(packResult->packedLinalgOp);
1904  continue;
1905  }
1906  results.push_back(linalgOp);
1907  }
1908  transformResults.set(cast<OpResult>(getPackedOp()), results);
1910 }
1911 
1912 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1913  Builder b(getContext());
1914  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1915  b);
1916 }
1917 
1918 void transform::PackGreedilyOp::getEffects(
1920  transform::consumesHandle(getTargetMutable(), effects);
1921  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1922  transform::producesHandle(getOperation()->getOpResults(), effects);
1923  transform::modifiesPayload(effects);
1924 }
1925 
1926 //===---------------------------------------------------------------------===//
1927 // PackTransposeOp
1928 //===---------------------------------------------------------------------===//
1929 
1930 LogicalResult transform::PackTransposeOp::verify() {
1931  if (!isPermutationVector(getInnerPerm())) {
1932  return emitOpError() << getInnerPermAttrName()
1933  << " is not a valid permutation";
1934  }
1935  if (!isPermutationVector(getOuterPerm())) {
1936  return emitOpError() << getOuterPermAttrName()
1937  << " is not a valid permutation";
1938  }
1939  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1940  return emitOpError() << " at least one of " << getInnerPermAttrName()
1941  << " or " << getOuterPermAttrName()
1942  << " must be specified";
1943  }
1944  return success();
1945 }
1946 
namespace {
// Discriminates whether a permutation applies to the outer dimensions
// (`outer_dims_perm`) or the inner dimensions (`inner_dims_pos`) of a
// pack/unpack op.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
1950 
1951 /// Return true if `permutation` is a valid permutation of the
1952 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1953 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op.
1954 /// This is the case when the `permutation` rank matches the rank expected by
1955 /// `op` and `permutation` is itself a permutation vector.
1956 /// Return true if either `op` or `permutation` are empty to allow a simpler
1957 /// polymorphic implementation.
1958 template <typename RelayoutOpTy>
1960  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1961  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1962  static_assert(
1963  llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1964  "applies to only pack or unpack operations");
1965  if (!op || permutation.empty())
1966  return true;
1967  size_t innerRank = op.getInnerDimsPos().size();
1968  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1969  return permutation.size() == innerRank && isPermutationVector(permutation);
1970  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1971  // Don't rely on it.
1972  if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1973  return permutation.size() == op.getSourceRank() &&
1974  isPermutationVector(permutation);
1975  }
1976  return permutation.size() == op.getDestRank() &&
1977  isPermutationVector(permutation);
1978 }
1979 
1981 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1982  transform::TransformResults &transformResults,
1983  transform::TransformState &state) {
1984  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1985  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1986  // Step 1. If nothing to pack, propagate success.
1987  if (std::empty(packOrUnpackOps)) {
1988  transformResults.set(cast<OpResult>(getPackedOp()), {});
1989  transformResults.set(cast<OpResult>(getPackOp()), {});
1990  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1992  }
1993 
1994  // Step 2. Bunch of runtime sanity check and error messages.
1995  // Step 2.1. Fail on multi-op handles.
1996  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1997  !llvm::hasSingleElement(linalgOps)) {
1998  return emitSilenceableError()
1999  << "requires target to map to exactly 1 "
2000  "packing op and 1 packed op ("
2001  << "got " << llvm::range_size(packOrUnpackOps) << " and "
2002  << llvm::range_size(linalgOps) << ")";
2003  }
2004 
2005  // Step 2.2. Fail on wrong type.
2006  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2007  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2008  if ((!packOp && !unPackOp)) {
2009  return emitSilenceableError() << "requires target to map to a "
2010  "linalg.pack or linalg.unpack";
2011  }
2012  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2013  if (!linalgOpTarget)
2014  return emitSilenceableError() << "requires a LinalgOp target";
2015 
2016  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
2017  LinalgOp linalgOp;
2018  if (packOp && packOp.getResult().hasOneUse())
2019  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2020  else if (unPackOp)
2021  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2022  if (linalgOp != linalgOpTarget) {
2023  auto errorMsg =
2024  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
2025  : StringLiteral{"not produced by the LinalgOp target"};
2026  return emitSilenceableError() << errorMsg;
2027  }
2028 
2029  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
2030  // PackOp.
2031  if (unPackOp) {
2032  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
2033  OpOperand *packUse = linalgOp.getDpsInitOperand(
2034  cast<OpResult>(unPackOp.getSource()).getResultNumber());
2035  packOp = packUse->get().getDefiningOp<linalg::PackOp>();
2036  if (!packOp || !packOp.getResult().hasOneUse())
2037  return emitSilenceableError() << "could not find matching pack op";
2038  }
2039 
2040  // Step 2.5. Fail if any permutation does not validate.
2041  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2042  ArrayRef<int64_t> perm =
2043  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2044  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2045  ? StringLiteral{"invalid outer_perm"}
2046  : StringLiteral{"invalid inner_perm"};
2047  if (!isValidPackingPermutation(packOp, perm, permType) ||
2048  !isValidPackingPermutation(unPackOp, perm, permType)) {
2049  Operation *packOrUnpackOp =
2050  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2051  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
2052  }
2053  }
2054 
2055  // From here on, packOp and linalgOp are always present, unPackOp may or may
2056  // not be present.
2057  assert(packOp && linalgOp && "unexpected null op");
2058 
2059  // Step 3. Actually transpose the ops.
2060  FailureOr<PackTransposeResult> res = packTranspose(
2061  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2062  // Preconditions have been checked, it is an error to fail here.
2063  assert(succeeded(res) && "unexpected packTranspose failure");
2064 
2065  // Step 4. Return results.
2066  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2067  transformResults.set(cast<OpResult>(getPackedOp()),
2068  {res->transposedLinalgOp});
2069  if (unPackOp) {
2070  transformResults.set(cast<OpResult>(getUnPackOp()),
2071  {res->transposedUnPackOp});
2072  } else {
2073  transformResults.set(cast<OpResult>(getUnPackOp()), {});
2074  }
2075 
2077 }
2078 
2079 //===---------------------------------------------------------------------===//
2080 // PadOp
2081 //===---------------------------------------------------------------------===//
2082 
2083 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2084  ArrayRef<int64_t> paddingDimensions,
2085  ArrayRef<int64_t> padToMultipleOf,
2086  ArrayRef<int64_t> nofoldFlags,
2087  ArrayRef<Attribute> transposePaddings,
2088  StringRef copyBackOp,
2089  bool usePrescribedTensorShapes) {
2090  auto resultType = transform::AnyOpType::get(b.getContext());
2091  return build(/*odsBuilder=*/b,
2092  /*result=*/result,
2093  /*types=*/TypeRange{resultType, resultType},
2094  /*target=*/target,
2095  /*padding_values=*/ArrayAttr(), // let inference handle this
2096  /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2097  /*pad_to_multiple_of=*/ValueRange{},
2098  /*padToMultipleOf=*/
2099  (padToMultipleOf.empty()
2100  ? DenseI64ArrayAttr()
2101  : b.getDenseI64ArrayAttr(padToMultipleOf)),
2102  /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2103  /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2104  /*copy_back_op=*/b.getStringAttr(copyBackOp),
2105  /*use_prescribed_tensor_shapes=*/
2106  usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
2107 }
2108 
2109 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2110  ArrayRef<int64_t> paddingDimensions,
2111  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
2112  ArrayRef<int64_t> nofoldFlags,
2113  ArrayRef<Attribute> transposePaddings,
2114  StringRef copyBackOp,
2115  bool usePrescribedTensorShapes) {
2116  auto resultType = transform::AnyOpType::get(b.getContext());
2117  SmallVector<int64_t> staticPadToMultipleOf;
2118  SmallVector<Value> dynamicPadToMultipleOf;
2119  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
2120  staticPadToMultipleOf);
2121  return build(/*odsBuilder=*/b,
2122  /*result=*/result,
2123  /*types=*/TypeRange{resultType, resultType},
2124  /*target=*/target,
2125  /*padding_values=*/ArrayAttr(), // let inference handle this
2126  /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2127  /*pad_to_multiple_of=*/dynamicPadToMultipleOf,
2128  /*padToMultipleOf=*/staticPadToMultipleOf,
2129  /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2130  /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2131  /*copy_back_op=*/copyBackOp,
2132  /*use_prescribed_tensor_shapes=*/usePrescribedTensorShapes);
2133 }
2134 
2135 void PadOp::getEffects(
2137  consumesHandle(getTargetMutable(), effects);
2138  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
2139  producesHandle(getOperation()->getOpResults(), effects);
2140  modifiesPayload(effects);
2141 }
2142 
2143 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2144  Builder b(getContext());
2145  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
2146 }
2147 
2149 transform::PadOp::apply(transform::TransformRewriter &rewriter,
2150  transform::TransformResults &results,
2151  transform::TransformState &state) {
2152  auto transformOp = cast<TransformOpInterface>(getOperation());
2153  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
2154 
2155  for (Operation *target : state.getPayloadOps(getTarget())) {
2156  auto linalgTarget = dyn_cast<LinalgOp>(target);
2157  if (!linalgTarget) {
2158  auto diag = emitSilenceableError() << "expected LinalgOp target";
2159  diag.attachNote(target->getLoc()) << "target op";
2160  return diag;
2161  }
2162 
2163  // Convert the integer packing flags to booleans.
2164  SmallVector<bool> nofoldFlags;
2165  for (int64_t packPadding :
2166  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
2167  nofoldFlags.push_back(static_cast<bool>(packPadding));
2168 
2169  // Convert the padding values to attributes.
2170  SmallVector<Attribute> paddingValues;
2171  for (auto const &[untypedAttr, elementOrTensorType] :
2172  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2173 
2174  if (isa<ub::PoisonAttr>(untypedAttr)) {
2175  paddingValues.push_back(untypedAttr);
2176  continue;
2177  }
2178  auto attr = dyn_cast<TypedAttr>(untypedAttr);
2179  if (!attr) {
2180  emitOpError("expects padding values to be typed attributes or poison");
2182  }
2183  Type elementType = getElementTypeOrSelf(elementOrTensorType);
2184  // Try to parse string attributes to obtain an attribute of element type.
2185  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2186  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2187  stringAttr, getContext(), elementType,
2188  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2189  if (!parsedAttr || parsedAttr.getType() != elementType) {
2190  auto diag = this->emitOpError("expects a padding that parses to ")
2191  << elementType << ", got " << untypedAttr;
2192  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2194  }
2195  paddingValues.push_back(parsedAttr);
2196  continue;
2197  }
2198  // Otherwise, add the attribute directly.
2199  if (attr.getType() != elementType) {
2200  auto diag = this->emitOpError("expects a padding value of type ")
2201  << elementType << ", got " << attr;
2202  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2204  }
2205  paddingValues.push_back(attr);
2206  }
2207 
2208  // Extract the transpose vectors.
2209  SmallVector<SmallVector<int64_t>> transposePaddings;
2210  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2211  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2212  cast<ArrayAttr>(transposeVector)));
2213 
2214  LinalgOp paddedOp;
2216  options.paddingDimensions =
2217  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2218 
2219  SmallVector<int64_t> padToMultipleOf;
2221  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2222  if (!status.succeeded())
2223  return status;
2224  if (padToMultipleOf.empty())
2225  padToMultipleOf =
2226  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2227 
2228  options.padToMultipleOf = padToMultipleOf;
2229  options.paddingValues = paddingValues;
2230  options.nofoldFlags = nofoldFlags;
2231  if (getCopyBackOp() ==
2232  bufferization::MaterializeInDestinationOp::getOperationName()) {
2235  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2237  } else if (getCopyBackOp() == kCopyOpNone) {
2239  } else {
2240  llvm_unreachable("unsupported copy_back op");
2241  }
2242  // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2243  bool irChanged = false;
2244  if (getUsePrescribedTensorShapes() &&
2245  linalgTarget.hasPureTensorSemantics()) {
2246  OpBuilder::InsertionGuard g(rewriter);
2247  rewriter.setInsertionPoint(linalgTarget);
2248  for (OpOperand &operand : linalgTarget->getOpOperands()) {
2249  for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2250  if (ShapedType::isStatic(dim))
2251  continue;
2252  options.setSizeToPadTo(operand.getOperandNumber(), i,
2253  tensor::getMixedSize(rewriter,
2254  operand.get().getLoc(),
2255  operand.get(), i));
2256  irChanged = true;
2257  }
2258  }
2259  }
2260 
2261  SmallVector<Value> replacements;
2262  SmallVector<tensor::PadOp> newPadOps;
2263  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2264  replacements, newPadOps))) {
2265  if (irChanged) {
2266  auto diag = emitDefiniteFailure() << "failed to pad op";
2267  diag.attachNote(target->getLoc()) << "target op";
2268  return diag;
2269  }
2270  auto diag = emitSilenceableError() << "failed to pad op";
2271  diag.attachNote(target->getLoc()) << "target op";
2272  return diag;
2273  }
2274 
2275  // We need to perform our own replacement here because this API is still
2276  // used in patterns that "pad and hoist", for which the replacement values
2277  // need to be different.
2278  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2279  // that we have more composable abstractions.
2280  rewriter.replaceOp(linalgTarget, replacements);
2281  paddedOps.push_back(paddedOp);
2282  padOps.append(newPadOps.begin(), newPadOps.end());
2283  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2284  for (Value v : replacements) {
2285  Operation *copyBackOp = v.getDefiningOp();
2286  if (!llvm::is_contained(copyBackOps, copyBackOp))
2287  copyBackOps.push_back(copyBackOp);
2288  }
2289  }
2290  }
2291 
2292  results.set(cast<OpResult>(getPadded()), paddedOps);
2293  results.set(cast<OpResult>(getPad()), padOps);
2294  results.set(cast<OpResult>(getCopy()), copyBackOps);
2296 }
2297 
2298 LogicalResult transform::PadOp::verify() {
2299  SmallVector<int64_t> nofoldFlags =
2300  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2301  if (any_of(nofoldFlags, [](int64_t packPadding) {
2302  return packPadding != 0 && packPadding != 1;
2303  })) {
2304  return emitOpError()
2305  << "expects nofold_flags to contain booleans (0/1), found "
2306  << getNofoldFlags();
2307  }
2308 
2309  SmallVector<int64_t> paddingDimensions =
2310  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2311  if (any_of(paddingDimensions,
2312  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2313  return emitOpError() << "expects padding_dimensions to contain positive "
2314  "integers, found "
2315  << getPaddingDimensions();
2316  }
2317  if (!getMixedPadToMultipleOf().empty()) {
2318  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2319  return emitOpError() << "expects as many multiples as padding_dimensions";
2320  }
2321  }
2322  ArrayAttr transposes = getTransposePaddings();
2323  for (Attribute attr : transposes) {
2324  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2325  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2326  if (!std::is_permutation(sequence.begin(), sequence.end(),
2327  transpose.begin(), transpose.end())) {
2328  return emitOpError()
2329  << "expects transpose_paddings to be a permutation, found "
2330  << attr;
2331  }
2332  }
2333  if (getCopyBackOp() !=
2334  bufferization::MaterializeInDestinationOp::getOperationName() &&
2335  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2336  getCopyBackOp() != kCopyOpNone)
2337  return emitOpError() << "invalid copy_back_op";
2338  return success();
2339 }
2340 
2341 //===---------------------------------------------------------------------===//
2342 // PadTilingInterfaceOp
2343 //===---------------------------------------------------------------------===//
2344 
2345 void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2346  OperationState &result,
2347  Value target,
2348  ArrayRef<int64_t> paddingSizes,
2349  bool padToMultipleOf) {
2350  auto resultType = transform::AnyOpType::get(b.getContext());
2351  return build(/*odsBuilder=*/b,
2352  /*result=*/result,
2353  /*types=*/TypeRange{resultType, resultType},
2354  /*target=*/target,
2355  /*padding_values=*/ArrayAttr(), // let inference handle this
2356  /*padding_sizes=*/ValueRange{},
2357  /*paddingSizes=*/
2358  (paddingSizes.empty() ? DenseI64ArrayAttr()
2359  : b.getDenseI64ArrayAttr(paddingSizes)),
2360  /*pad_to_multiple_of=*/
2361  padToMultipleOf ? b.getUnitAttr() : nullptr);
2362 }
2363 
/// Convenience builder taking mixed (static attribute / dynamic value)
/// padding sizes; splits them into the two variadic operand groups expected
/// by the generated builder.
void transform::PadTilingInterfaceOp::build(
    OpBuilder &b, OperationState &result, Value target,
    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
  auto resultType = transform::AnyOpType::get(b.getContext());
  SmallVector<int64_t> staticPaddingSizes;
  SmallVector<Value> dynamicPaddingSizes;
  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
                             staticPaddingSizes);
  return build(/*odsBuilder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType},
               /*target=*/target,
               /*padding_values=*/ArrayAttr(), // let inference handle this
               /*padding_sizes=*/dynamicPaddingSizes,
               /*paddingSizes=*/staticPaddingSizes,
               /*pad_to_multiple_of=*/padToMultipleOf);
}
2381 
2382 void transform::PadTilingInterfaceOp::getEffects(
2384  consumesHandle(getTargetMutable(), effects);
2385  onlyReadsHandle(getPaddingSizesMutable(), effects);
2386  producesHandle(getOperation()->getOpResults(), effects);
2387  modifiesPayload(effects);
2388 }
2389 
2391 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2392  Builder b(getContext());
2393  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2394 }
2395 
2397 transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2398  transform::TransformResults &results,
2399  transform::TransformState &state) {
2400  SmallVector<Operation *> paddedOps, padOps;
2401 
2402  for (Operation *target : state.getPayloadOps(getTarget())) {
2403  auto targetOp = dyn_cast<TilingInterface>(target);
2404  if (!targetOp) {
2405  auto diag = emitSilenceableError() << "expected TilingInterface target";
2406  diag.attachNote(target->getLoc()) << "target op";
2407  return diag;
2408  }
2409 
2410  // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2411  // loopsToOperand map / C++ APIs to compute the effect of padding on
2412  // operands.
2413  if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2414  auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2415  "supported atm";
2416  diag.attachNote(target->getLoc()) << "target op";
2417  return diag;
2418  }
2419 
2420  // Convert the padding values to attributes.
2421  SmallVector<Attribute> paddingValues;
2422  for (auto const &[untypedAttr, elementOrTensorType] :
2423  llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2424  auto attr = dyn_cast<TypedAttr>(untypedAttr);
2425  Type elementType = getElementTypeOrSelf(elementOrTensorType);
2426 
2427  if (isa<ub::PoisonAttr>(untypedAttr)) {
2428  paddingValues.push_back(untypedAttr);
2429  continue;
2430  }
2431  if (!attr) {
2432  emitOpError("expects padding values to be typed attributes or poison");
2434  }
2435  // Try to parse string attributes to obtain an attribute of element type.
2436  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2437  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2438  stringAttr, getContext(), elementType,
2439  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2440  if (!parsedAttr || parsedAttr.getType() != elementType) {
2441  auto diag = this->emitOpError("expects a padding that parses to ")
2442  << elementType << ", got " << attr;
2443  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2445  }
2446  paddingValues.push_back(parsedAttr);
2447  continue;
2448  }
2449  // Otherwise, add the attribute directly.
2450  if (attr.getType() != elementType) {
2451  auto diag = this->emitOpError("expects a padding value of type ")
2452  << elementType << ", got " << attr;
2453  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2455  }
2456  paddingValues.push_back(attr);
2457  }
2458 
2459  // Set options.
2460  TilingInterface paddedOp;
2462  options.setPaddingValues(paddingValues)
2463  .setPaddingSizes(getMixedPaddingSizes())
2464  .setPadToMultipleOf(getPadToMultipleOf());
2465 
2466  // Apply padding.
2467  SmallVector<tensor::PadOp> newPadOps;
2468  FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
2469  rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
2470  newPadOps);
2471  if (failed(maybePaddedOp)) {
2472  auto diag = emitSilenceableError() << "failed to pad op";
2473  diag.attachNote(target->getLoc()) << "target op";
2474  return diag;
2475  }
2476 
2477  // Set transform results.
2478  paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2479  padOps.append(newPadOps.begin(), newPadOps.end());
2480  }
2481 
2482  results.set(cast<OpResult>(getPadded()), paddedOps);
2483  results.set(cast<OpResult>(getPad()), padOps);
2485 }
2486 
// No invariants beyond the ODS-generated verification.
LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2488 
2489 //===---------------------------------------------------------------------===//
2490 // HoistPadOp
2491 //===---------------------------------------------------------------------===//
2492 
2493 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2494  transform::TransformRewriter &rewriter,
2495  transform::TransformResults &transformResults,
2496  transform::TransformState &state) {
2497  auto targetOps = state.getPayloadOps(getTarget());
2498  auto loopOps = state.getPayloadOps(getLoop());
2499  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2500  return emitDefiniteFailure()
2501  << "requires exactly one target and one loop handle (got "
2502  << llvm::range_size(targetOps) << " and "
2503  << llvm::range_size(loopOps) << ")";
2504  }
2505 
2506  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2507  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2508  if (!padOp || !loopOp)
2509  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2510 
2511  FailureOr<linalg::detail::PackingResult> result =
2512  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2513  getTranspose());
2514  if (failed(result))
2515  return emitDefiniteFailure() << "could not build packing loop nest";
2516 
2517  if (result->clonedLoopIvs.empty()) {
2518  transformResults.set(cast<OpResult>(getPackingLoop()),
2519  {result->hoistedPadOp.getOperation()});
2521  }
2522  auto outerPackedLoop =
2523  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2524  transformResults.set(cast<OpResult>(getPackingLoop()),
2525  {outerPackedLoop.getOperation()});
2527 }
2528 
2530  ArrayRef<int64_t> transpose = getTranspose();
2531  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2532  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2533  transpose.end())) {
2534  return emitOpError() << "expects transpose to be a permutation, found "
2535  << getTranspose();
2536  }
2537  return success();
2538 }
2539 
2540 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2542  transform::onlyReadsHandle(getTargetMutable(), effects);
2543  transform::onlyReadsHandle(getLoopMutable(), effects);
2544  transform::producesHandle(getOperation()->getOpResults(), effects);
2545  transform::modifiesPayload(effects);
2546 }
2547 
2549 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2550  tensor::PadOp target,
2552  transform::TransformState &state) {
2553  tensor::PadOp hoistedPadOp;
2554  SmallVector<TransposeOp> transposeOps;
2555  FailureOr<Value> result =
2556  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2557  hoistedPadOp, transposeOps);
2558  if (succeeded(result)) {
2559  // We need to perform our own replacement here because this API is still
2560  // used in patterns that "pad and hoist", for which the replacement values
2561  // need to be different.
2562  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2563  // that we have more composable abstractions.
2564  rewriter.replaceOp(target, *result);
2565  results.push_back(hoistedPadOp);
2567  }
2568  return emitDefaultSilenceableFailure(target);
2569 }
2570 
2571 LogicalResult transform::HoistPadOp::verify() {
2572  ArrayRef<int64_t> transpose = getTranspose();
2573  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2574  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2575  transpose.end())) {
2576  return emitOpError() << "expects transpose to be a permutation, found "
2577  << getTranspose();
2578  }
2579  return success();
2580 }
2581 
2582 //===----------------------------------------------------------------------===//
2583 // PromoteOp
2584 //===----------------------------------------------------------------------===//
2585 
// Promotes operand subviews of the target LinalgOp into local buffers,
// configured by the op's attributes (operands to promote, full-tile buffers,
// alloca usage, alignment, memory space, and an optional GPU address-space
// mapping).
// NOTE(review): the doxygen extraction dropped several lines in this function
// (the return-type line, the result-list parameter, the GPU
// allocation/deallocation and copy-callback chain in both mapping branches,
// and the trailing success return); compare against upstream
// LinalgTransformOps.cpp before relying on this listing.
transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
                                 LinalgOp target,
                                 transform::TransformState &state) {
  LinalgPromotionOptions promotionOptions;
  // Each setter is only invoked when the corresponding attribute is present,
  // so unset attributes keep the LinalgPromotionOptions defaults.
  if (!getOperandsToPromote().empty())
    promotionOptions = promotionOptions.setOperandsToPromote(
        extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
  if (getUseFullTilesByDefault())
    promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
        getUseFullTilesByDefault());
  if (getUseOriginalSubviewSize())
    promotionOptions =
        promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
  if (getUseAlloca())
    promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
  if (!getUseFullTileBuffers().empty())
    promotionOptions = promotionOptions.setUseFullTileBuffers(
        llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
  if (getAlignment().has_value())
    promotionOptions = promotionOptions.setAlignment(*getAlignment());
  if (getMemorySpace().has_value())
    promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());

  if (getMapping().has_value()) {
    // The mapping should only contain an element
    auto mapping = *getMapping();
    if (mapping.size() > 1)
      return emitDefaultDefiniteFailure(target);

    auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);

    // Workgroup vs. private GPU memory select different promotion callbacks.
    if (addressSpace.getAddressSpace() ==
        mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setUseFullTileBuffers({false, false});
    } else if (addressSpace.getAddressSpace() ==
               mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setUseFullTileBuffers({false, false});
    } else {
      // Any other address space is rejected.
      return emitDefaultDefiniteFailure(target);
    }
  }

  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
    return emitDefaultDefiniteFailure(target);

  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
  if (failed(res))
    return emitDefaultDefiniteFailure(target);
  // Promotion happens in place; the target op itself is the result.
  results.push_back(target);
}
2650 
2651 //===----------------------------------------------------------------------===//
2652 // ReplaceOp
2653 //===----------------------------------------------------------------------===//
2654 
2656 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2657  TransformResults &transformResults,
2658  TransformState &state) {
2659  auto payload = state.getPayloadOps(getTarget());
2660 
2661  // Check for invalid targets.
2662  for (Operation *target : payload) {
2663  if (target->getNumOperands() > 0)
2664  return emitDefiniteFailure() << "expected target without operands";
2665  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2666  target->getNumRegions() > 0)
2667  return emitDefiniteFailure()
2668  << "expected target that is isolated from above";
2669  }
2670 
2671  // Clone and replace.
2672  Operation *pattern = &getBodyRegion().front().front();
2673  SmallVector<Operation *> replacements;
2674  for (Operation *target : payload) {
2675  if (getOperation()->isAncestor(target))
2676  continue;
2677  rewriter.setInsertionPoint(target);
2678  Operation *replacement = rewriter.clone(*pattern);
2679  rewriter.replaceOp(target, replacement->getResults());
2680  replacements.push_back(replacement);
2681  }
2682  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2684 }
2685 
2686 void transform::ReplaceOp::getEffects(
2688  consumesHandle(getTargetMutable(), effects);
2689  producesHandle(getOperation()->getOpResults(), effects);
2690  modifiesPayload(effects);
2691 }
2692 
2693 LogicalResult transform::ReplaceOp::verify() {
2694  if (!getBodyRegion().hasOneBlock())
2695  return emitOpError() << "expected one block";
2696  if (std::distance(getBodyRegion().front().begin(),
2697  getBodyRegion().front().end()) != 1)
2698  return emitOpError() << "expected one operation in block";
2699  Operation *replacement = &getBodyRegion().front().front();
2700  if (replacement->getNumOperands() > 0)
2701  return replacement->emitOpError()
2702  << "expected replacement without operands";
2703  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2704  replacement->getNumRegions() > 0)
2705  return replacement->emitOpError()
2706  << "expect op that is isolated from above";
2707  return success();
2708 }
2709 
2710 //===----------------------------------------------------------------------===//
2711 // ScalarizeOp
2712 //===----------------------------------------------------------------------===//
2713 
// Tiles every dynamic dimension of the target LinalgOp by 1 (static
// dimensions get tile size 0, i.e. are left untiled), effectively scalarizing
// the dynamic dimensions.
// NOTE(review): the doxygen extraction dropped lines here (the return-type
// line, the result-list parameter, the call that computes `shapeSizes` from
// `map` and `allShapeSizes` — leaving the dangling `allShapeSizes);` below —
// and the trailing success return); compare against upstream before relying
// on this listing.
transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::TransformState &state) {
  scf::SCFTilingOptions tilingOptions;
  // Tile sizes are computed lazily at tiling time from the op's shape.
  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
    SmallVector<OpFoldResult> tileSizes;
    Location loc = target.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        target.createFlatListOfOperandDims(b, loc);
    AffineMap map = target.getShapesToLoopsMap();
    if (!map)
      return tileSizes;
    SmallVector<OpFoldResult> shapeSizes =
        allShapeSizes);
    // If the shape size is dynamic, tile by 1.
    // Otherwise, do not tile (i.e. tile size 0).
    for (OpFoldResult shapeSize : shapeSizes) {
      tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
                                                         : b.getIndexAttr(1));
    }
    return tileSizes;
  });
  rewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
      rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
  if (failed(maybeTilingResult))
    return emitDefaultDefiniteFailure(target);

  // Ops without results are erased rather than replaced.
  if (target->getNumResults())
    rewriter.replaceOp(target, maybeTilingResult->replacements);
  else
    rewriter.eraseOp(target);

  results.reserve(maybeTilingResult->tiledOps.size());
  for (Operation *tiled : maybeTilingResult->tiledOps)
    results.push_back(tiled);
}
2755 
2756 //===----------------------------------------------------------------------===//
2757 // ConvertToLoopsOp
2758 //===----------------------------------------------------------------------===//
2759 
// Lowers each payload op implementing TilingInterface to an explicit scf.for
// loop nest and erases the original op; all generated loops are returned
// through the single result handle.
// NOTE(review): the doxygen extraction dropped lines here (the return-type
// line, the declaration of the `loops` vector, the `diag` declaration that
// receives emitSilenceableError(), and the trailing success return); compare
// against upstream before relying on this listing.
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto tilingOp = dyn_cast<TilingInterface>(*target);
    if (!tilingOp) {
      emitSilenceableError()
          << "expected the payload to implement TilingInterface";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }
    rewriter.setInsertionPoint(target);
    FailureOr<SmallVector<scf::ForOp>> generatedLoops =
        scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
    if (failed(generatedLoops))
      return emitDefaultDefiniteFailure(target);
    // Accumulate all loops across all targets, outermost first.
    for (scf::ForOp &loop : *generatedLoops) {
      loops.push_back(loop.getOperation());
    }
    rewriter.eraseOp(target);
  }
  results.set(cast<OpResult>(getResult()), loops);
}
2787 
2788 //===----------------------------------------------------------------------===//
2789 // RewriteInDestinationPassingStyleOp
2790 //===----------------------------------------------------------------------===//
2791 
2793 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2794  transform::TransformRewriter &rewriter, Operation *target,
2796  transform::TransformState &state) {
2797  rewriter.setInsertionPoint(target);
2798  FailureOr<Operation *> maybeResult =
2800  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2801  [&rewriter](auto op) {
2802  return rewriteInDestinationPassingStyle(rewriter, op);
2803  });
2804  if (failed(maybeResult))
2805  return emitDefaultSilenceableFailure(target);
2806  results.push_back(*maybeResult);
2808 }
2809 
2810 //===----------------------------------------------------------------------===//
2811 // SplitOp
2812 //===----------------------------------------------------------------------===//
2813 
// Splits each target structured op along `getDimension()` at the given chunk
// size(s). In multiway mode a single target is split at multiple points; in
// the default mode each target is split once into (first, second) parts.
// NOTE(review): the doxygen extraction dropped lines in this function (the
// return-type line, several `DiagnosedSilenceableFailure diag =` declarations
// feeding the `diag.isSilenceableFailure()` checks, and success returns);
// compare against upstream before relying on this listing.
SplitOp::apply(transform::TransformRewriter &rewriter,
               TransformResults &results, TransformState &state) {
  // Collect the dynamic split points if provided.
  SmallVector<Operation *> payload =
      llvm::to_vector(state.getPayloadOps(getTarget()));

  bool isMultiwaySplit = getMultiway();

  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
    return mlir::emitSilenceableFailure(getLoc())
           << "requires exactly one target when "
              "multiway split is enabled (got "
           << llvm::range_size(payload) << ")";
  }

  SmallVector<OpFoldResult> chunkSizes;

  if (!isMultiwaySplit)
    chunkSizes.reserve(payload.size());

  if (getDynamicChunkSizes()) {
    // Dynamic chunk sizes come either from payload ops (handle) or from
    // transform params (attributes).
    if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
      chunkSizes = llvm::to_vector(llvm::map_range(
          state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
            if (op->getNumResults() != 1 ||
                !op->getResult(0).getType().isIndex()) {
              diag = emitSilenceableError()
                     << "expected dynamic split point handle to point to a "
                        "single-result index-typed op";
              diag.attachNote(op->getLoc()) << "dynamic split point";
            }
            return OpFoldResult(op->getResult(0));
          }));
    } else {
      chunkSizes = llvm::to_vector(
          llvm::map_range(state.getParams(getDynamicChunkSizes()),
                          [](Attribute attr) { return OpFoldResult(attr); }));
    }
    if (diag.isSilenceableFailure())
      return diag;

    // For multiway split, a single payload is expected to have multiple
    // split points.
    if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
      return emitDefiniteFailure()
             << "expected the dynamic split point handle to point to as "
                "many operations ("
             << chunkSizes.size() << ") as the target handle ("
             << payload.size() << ")";
    }
  } else {
    // Static chunk size: replicate it for every target.
    chunkSizes.resize(payload.size(),
                      rewriter.getIndexAttr(getStaticChunkSizes()));
  }

  // Emits a silenceable error unless `linalgOp` is a structured op whose loop
  // count exceeds the requested split dimension.
  auto checkStructuredOpAndDimensions =
      [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
    if (!linalgOp) {
      auto diag = emitSilenceableError() << "only applies to structured ops";
      diag.attachNote(loc) << "target op";
      return diag;
    }

    if (getDimension() >= linalgOp.getNumLoops()) {
      auto diag = emitSilenceableError() << "dimension " << getDimension()
                                         << " does not exist in target op";
      diag.attachNote(loc) << "target op";
      return diag;
    }
  };

  // Converts a boolean splitting failure into a definite error diagnostic.
  auto checkFailureInSplitting =
      [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
    if (hasFailed) {
      auto diag = emitDefiniteFailure() << "internal failure in splitting";
      diag.attachNote(loc) << "target op";
      return diag;
    }
  };

  SmallVector<Operation *> opList;
  if (isMultiwaySplit) {

    // Split a single target operation at multiple points.
    TilingInterface head, tail;
    Operation *target = payload.front();

    LinalgOp linalgOp = dyn_cast<LinalgOp>(target);

    // Check that the target is a valid LinalgOp with correct dimensions.
    checkStructuredOpAndDimensions(linalgOp, target->getLoc());
    if (diag.isSilenceableFailure())
      return diag;

    for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {

      // After the first iteration, keep splitting the remaining tail.
      if (idx > 0)
        target = tail.getOperation();

      if (!target)
        break;

      linalgOp = cast<LinalgOp>(target);
      Location loc = target->getLoc();

      rewriter.setInsertionPoint(linalgOp);
      std::tie(head, tail) = linalg::splitOp(
          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
          getDimension(), chunkSize);

      // Propagate errors.
      checkFailureInSplitting(!head && !tail, loc);
      if (diag.isDefiniteFailure())
        return diag;

      opList.push_back(head.getOperation());
    }

    // Append any leftover parts to the end of the result list.
    if (tail)
      opList.push_back(tail.getOperation());

  } else {
    // Split each target operation.
    SmallVector<Operation *> first, second;
    Operation *noSecondPart = nullptr;
    for (const auto &pair : llvm::zip(payload, chunkSizes)) {
      Operation *target = std::get<0>(pair);
      Location loc = target->getLoc();
      LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
      checkStructuredOpAndDimensions(linalgOp, target->getLoc());

      if (diag.isSilenceableFailure())
        return diag;

      rewriter.setInsertionPoint(linalgOp);
      std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
          rewriter, cast<TilingInterface>(linalgOp.getOperation()),
          getDimension(), std::get<1>(pair));

      // Propagate errors.
      // NOTE(review): the check below inspects `diagSplit` but returns `diag`
      // — this looks like it should return `diagSplit`; confirm against
      // upstream before changing.
      DiagnosedSilenceableFailure diagSplit =
          checkFailureInSplitting(!first.back() && !second.back(), loc);
      if (diagSplit.isDefiniteFailure())
        return diag;

      // Do not add null second parts.
      if (!second.back()) {
        noSecondPart = target;
        second.pop_back();
      }
    }

    if (second.size() != first.size() && !second.empty()) {
      auto diag = emitSilenceableError()
                  << "splitting does not produce the second part for a subset "
                     "of targets";
      diag.attachNote()
          << "expected splitting to produce the second part of all "
             "or none of the targets";
      diag.attachNote(noSecondPart->getLoc())
          << "first target with no second part";
      return diag;
    }

    // First parts come before all second parts in the result list.
    opList.append(first);
    if (!second.empty())
      opList.append(second);
  }
  results.set(cast<OpResult>(getSplitList()), opList);
}
2993 
2994 void SplitOp::getEffects(
2996  consumesHandle(getTargetMutable(), effects);
2997  if (getDynamicChunkSizes())
2998  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2999  producesHandle(getOperation()->getOpResults(), effects);
3000  modifiesPayload(effects);
3001 }
3002 
3003 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
3004  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
3005  IntegerAttr staticChunkSizes;
3006  if (parser.parseOperand(target) || parser.parseKeyword("after"))
3007  return failure();
3008 
3009  OptionalParseResult dynamicPointParseResult =
3010  parser.parseOptionalOperand(dynamicChunkSizes);
3011  if (!dynamicPointParseResult.has_value()) {
3012  int64_t staticChunkSizesValue;
3013  if (failed(parser.parseInteger(staticChunkSizesValue)))
3014  return failure();
3015 
3016  staticChunkSizes =
3017  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
3018  }
3019 
3020  Type targetType;
3021  if (parser.parseOptionalAttrDict(result.attributes) ||
3022  parser.parseColonType(targetType) ||
3023  parser.resolveOperand(target, targetType, result.operands)) {
3024  return failure();
3025  }
3026  if (dynamicPointParseResult.has_value()) {
3027  Type ChunkSizesType;
3028  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
3029  parser.parseType(ChunkSizesType) ||
3030  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
3031  result.operands)) {
3032  return failure();
3033  }
3034 
3035  staticChunkSizes =
3036  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
3037  }
3038 
3039  result.addAttribute(
3040  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
3041  staticChunkSizes);
3042  result.addTypes(targetType);
3043  return success();
3044 }
3045 
3046 void SplitOp::print(OpAsmPrinter &printer) {
3047  printer << " " << getTarget() << " after ";
3048  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
3049  if (staticChunkSize != ShapedType::kDynamic)
3050  printer << staticChunkSize;
3051  else
3052  printer << getDynamicChunkSizes();
3053  printer << " ";
3054  printer.printOptionalAttrDict(getOperation()->getAttrs(),
3055  {getStaticChunkSizesAttrName()});
3056  printer << " : " << getTarget().getType();
3057  if (staticChunkSize == ShapedType::kDynamic)
3058  printer << ", " << getDynamicChunkSizes().getType();
3059 }
3060 
3061 LogicalResult SplitOp::verify() {
3062  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3063  (getDynamicChunkSizes() == nullptr)) {
3064  return emitOpError() << "expects either a dynamic or a static split "
3065  "point to be provided";
3066  }
3067  return success();
3068 }
3069 
3070 //===----------------------------------------------------------------------===//
3071 // SplitReductionOp
3072 //===----------------------------------------------------------------------===//
3073 
3074 void transform::SplitReductionOp::build(
3075  OpBuilder &builder, OperationState &result, Value target,
3076  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
3077  bool useScalingAlgorithm, bool useAlloc) {
3078  MLIRContext *ctx = builder.getContext();
3079  result.addOperands(target);
3080  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
3081  builder.getI64IntegerAttr(splitFactor));
3082  result.addAttribute(
3083  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
3084  builder.getI64IntegerAttr(insertSplitDimension));
3085  if (innerParallel) {
3086  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
3087  builder.getUnitAttr());
3088  }
3089  if (useScalingAlgorithm) {
3090  result.addAttribute(
3091  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
3092  builder.getUnitAttr());
3093  }
3094  if (useAlloc) {
3095  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
3096  builder.getUnitAttr());
3097  }
3098  auto resultType = transform::AnyOpType::get(ctx);
3099  result.addTypes({resultType, resultType, resultType, resultType});
3100 }
3101 
3102 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3103  transform::TransformRewriter &rewriter, LinalgOp target,
3105  transform::TransformState &state) {
3106  ControlSplitReductionFn splitFn = [&](LinalgOp) {
3107  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3108  unsigned(getInsertSplitDimension()),
3109  bool(getInnerParallel())};
3110  };
3111  rewriter.setInsertionPoint(target);
3112  FailureOr<SplitReductionResult> splitResult =
3113  (getUseScalingAlgorithm())
3114  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
3115  : splitReduction(rewriter, target, splitFn, getUseAlloc());
3116  if (failed(splitResult))
3117  return emitDefaultDefiniteFailure(target);
3118 
3119  results.push_back(splitResult->initOrAlloc);
3120  results.push_back(splitResult->fillOp);
3121  results.push_back(splitResult->splitLinalgOp);
3122  results.push_back(splitResult->resultCombiningLinalgOp);
3124 }
3125 
3126 //===----------------------------------------------------------------------===//
3127 // TileReductionUsingForOp
3128 //===----------------------------------------------------------------------===//
3129 
3130 void transform::TileReductionUsingForOp::build(
3131  OpBuilder &builder, OperationState &result, Value target,
3132  ArrayRef<int64_t> staticTileSizes) {
3133  // Call the default builder.
3134  // This is future-proof re mixed static-dynamic and setting up the proper
3135  // operands segment sizes attributes for multiple variadic operands.
3136  // In the absence of this, horrible bugs ensue.
3137  // TODO: support mixed static-dynamic (see TileUsingForallOp).
3138  MLIRContext *ctx = builder.getContext();
3139  auto opTy = transform::AnyOpType::get(ctx);
3140  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
3141  build(builder, result,
3142  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3143  /*target=*/target,
3144  /*reduction_dims=*/nullptr,
3145  /*tile_sizes=*/staticTileSizesAttr);
3146 }
3147 
// Tiles the reduction dimensions of the target op (which must implement
// PartialReductionOpInterface) with scf.for loops using the partial-reduction
// strategy, reporting init values, tiled ops, merge ops, and the outer loop.
// NOTE(review): the doxygen extraction dropped lines here (the result-list
// parameter, the declaration of `options` as scf::SCFTilingOptions with its
// loop type, the strategy argument of setReductionTilingStrategy, and the
// trailing success return); compare against upstream before relying on this
// listing.
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);

  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
  if (!partialReductionOp) {
    return emitSilenceableFailure(
        target->getLoc(),
        "Operation should implement PartialReductionOpInterface");
  }

  // If no reduction dims were given, default to all reduction iterators.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [idx, iteratorType] :
         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(idx);
    }
  }

  options.setReductionTilingStrategy(
  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
  options.setReductionDims(reductionDims);
  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCF(rewriter, partialReductionOp, options);

  if (failed(result)) {
    return emitSilenceableFailure(getLoc(),
                                  "failed to tile using partial reduction");
  }
  rewriter.replaceOp(target, result->replacements);
  // Result order: init values, tiled ops, merge ops, outermost loop.
  for (Value initValue : result->initialValues)
    results.push_back(initValue.getDefiningOp());
  for (auto parallelTiledOp : result->tiledOps)
    results.push_back(parallelTiledOp);
  for (auto mergeOp : result->mergeOps)
    results.push_back(mergeOp);
  results.push_back(result->loops.front());
}
3194 
3195 //===----------------------------------------------------------------------===//
3196 // TileReductionUsingForallOp
3197 //===----------------------------------------------------------------------===//
3198 
3199 void transform::TileReductionUsingForallOp::build(
3200  OpBuilder &builder, OperationState &result, Value target,
3201  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3202  ArrayAttr mapping) {
3203  // Call the default builder.
3204  // This is future-proof re mixed static-dynamic and setting up the proper
3205  // operands segment sizes attributes for multiple variadic operands.
3206  // In the absence of this, horrible bugs ensue.
3207  // TODO: support mixed static-dynamic (see TileUsingForallOp).
3208  MLIRContext *ctx = builder.getContext();
3209  auto opTy = transform::AnyOpType::get(ctx);
3210  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3211  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3212  build(builder, result,
3213  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3214  /*target=*/target,
3215  /*reduction_dims=*/{},
3216  /*num_threads=*/staticNumThreadsAttr,
3217  /*tile_sizes=*/staticTileSizesAttr,
3218  /*mapping=*/mapping);
3219 }
3220 
// Tiles the reduction dimensions of the target op (which must implement
// PartialReductionOpInterface) with an scf.forall loop, configured either by
// thread counts or by tile sizes, with an optional thread mapping.
// NOTE(review): the doxygen extraction dropped lines here (the result-list
// parameter, the initializer of `tileSizes`, the declaration of `options`
// with its forall loop type, the strategy argument of
// setReductionTilingStrategy, and the trailing success return); compare
// against upstream before relying on this listing.
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);

  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
  if (!partialReductionOp) {
    return emitSilenceableFailure(
        target->getLoc(),
        "Operation should implement PartialReductionOpInterface");
  }
  SmallVector<OpFoldResult> numThreads =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
  SmallVector<OpFoldResult> tileSizes =

  options.setReductionTilingStrategy(
  // Thread counts take precedence over tile sizes when both are present.
  if (!getNumThreads().empty()) {
    options.setNumThreads(numThreads);
  } else {
    options.setTileSizes(tileSizes);
  }
  if (auto mapping = getMapping()) {
    options.setMapping(mapping.value().getValue());
  }
  // If no reduction dims were given, default to all reduction iterators.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [idx, iteratorType] :
         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(idx);
    }
  }
  options.setReductionDims(reductionDims);
  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCF(rewriter, partialReductionOp, options);

  if (failed(result)) {
    auto diag = emitSilenceableError() << "could not tile reduction";
    return diag;
  }
  rewriter.replaceOp(target, result->replacements);

  // Result order: init values, tiled ops, merge ops, outermost loop.
  for (Value initValue : result->initialValues)
    results.push_back(initValue.getDefiningOp());
  for (auto parallelTiledOp : result->tiledOps)
    results.push_back(parallelTiledOp);
  for (auto mergeOp : result->mergeOps)
    results.push_back(mergeOp);
  results.push_back(result->loops.front());
}
3278 
3279 //===----------------------------------------------------------------------===//
3280 // ContinuousTileSizesOp
3281 //===----------------------------------------------------------------------===//
3282 
// Computes continuous (power-of-two-decreasing) tile sizes and the resulting
// chunk sizes for one dimension of a single Linalg target. When the results
// are params, static sizes are computed as attributes; otherwise IR computing
// the sizes is emitted and the defining ops are returned.
// NOTE(review): the doxygen extraction dropped the return-type line and the
// two success returns (after the param branch and at function end); compare
// against upstream before relying on this listing.
transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {

  SmallVector<Operation *> targetOps =
      llvm::to_vector(state.getPayloadOps(getTarget()));

  if (!llvm::hasSingleElement(targetOps)) {
    return mlir::emitSilenceableFailure(getLoc())
           << "requires exactly one target (got " << llvm::range_size(targetOps)
           << ")";
  }

  Operation *target = *targetOps.begin();
  auto linalgOp = dyn_cast<LinalgOp>(target);
  auto tileableOp = dyn_cast<TilingInterface>(target);

  if (!linalgOp)
    return emitDefiniteFailure() << "expected Linalg Op";

  OpBuilder builder(linalgOp.getContext());

  // Param results: compute static sizes; requires a static payload shape.
  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
    if (linalgOp.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(linalgOp->getLoc()) << "payload op";
      return diag;
    }

    FailureOr<StaticContinuousTileSizeSpecification> spec =
        computeStaticContinuousTileSizes(linalgOp, getDimension(),
                                         getTargetSize());
    if (failed(spec)) {
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";
    }

    SmallVector<int64_t> chunkSizes;

    // Chunk size = tile size * its trip count.
    for (auto &&[tileSize, tripCount] :
         llvm::zip_equal(spec->tileSizes, spec->tripCounts))
      chunkSizes.push_back(tileSize * tripCount);

    auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
      return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
        return builder.getI64IntegerAttr(value);
      });
    };
    transformResults.setParams(cast<OpResult>(getTileSizes()),
                               getI64AttrsFromI64(spec->tileSizes));
    transformResults.setParams(cast<OpResult>(getChunkSizes()),
                               getI64AttrsFromI64(chunkSizes));

  }

  // Handle results: emit IR computing the sizes in front of the payload op.
  builder.setInsertionPoint(linalgOp);

  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
  unsigned dimension = getDimension();

  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
      builder, tileableOp, dimension, targetSize, true);
  if (failed(spec)) {
    return emitSilenceableError() << "could not generate tile size computation";
  }

  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  AffineExpr s1 = builder.getAffineSymbolExpr(1);
  // Materializes expr(ofrs) as an affine.apply op.
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
                                           ofrs);
  };

  SmallVector<Value> chunkSizes;
  Value splitPoint;
  // Chunk size = tile size * trip count, computed in IR.
  for (auto &&[tileSize, tripCount] :
       llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
    splitPoint = apply(s0 * s1, {tileSize, tripCount});
    chunkSizes.push_back(splitPoint);
  }

  auto getDefiningOps = [&](ArrayRef<Value> values) {
    return llvm::map_to_vector(values, [&](Value value) -> Operation * {
      return value.getDefiningOp();
    });
  };

  transformResults.set(cast<OpResult>(getTileSizes()),
                       getDefiningOps(spec->tileSizes));
  transformResults.set(cast<OpResult>(getChunkSizes()),
                       getDefiningOps(chunkSizes));

}
3381 
3383 
3384  if (getTileSizes().getType() != getChunkSizes().getType()) {
3385  return emitOpError() << "expects all results type to be the same";
3386  }
3387 
3388  return success();
3389 }
3390 
3391 void transform::ContinuousTileSizesOp::getEffects(
3393  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3394  onlyReadsPayload(effects);
3395  else
3396  modifiesPayload(effects);
3397  onlyReadsHandle(getTargetMutable(), effects);
3398  producesHandle(getOperation()->getOpResults(), effects);
3399 }
3400 
3402  Type targetType, Type tile_sizes,
3403  Type) {
3404  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
3405 }
3406 
3407 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
3408  Type &targetType,
3409  Type &tileSizesType,
3410  Type &chunkSizesType) {
3411  FunctionType funcType;
3412  llvm::SMLoc typeLoc = parser.getCurrentLocation();
3413  if (failed(parser.parseType<FunctionType>(funcType)))
3414  return failure();
3415 
3416  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3417  parser.emitError(typeLoc) << "expects a trailing functional type with one "
3418  "argument and one result";
3419  }
3420  targetType = funcType.getInput(0);
3421  tileSizesType = chunkSizesType = funcType.getResult(0);
3422 
3423  return success();
3424 }
3425 
3426 //===----------------------------------------------------------------------===//
3427 // TileUsingForOp
3428 //===----------------------------------------------------------------------===//
3429 
3430 void transform::TileUsingForOp::build(
3431  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3432  Value target, ArrayRef<int64_t> staticTileSizes,
3433  ArrayRef<int64_t> interchange,
3434  std::optional<ArrayRef<bool>> scalableSizes) {
3435  return build(builder, result, loopTypes,
3436  /*target=*/target,
3437  /*mixedTileSizes=*/
3438  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3439  interchange, scalableSizes);
3440 }
3441 
3442 void transform::TileUsingForOp::build(
3443  OpBuilder &builder, OperationState &result, Value target,
3444  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3445  std::optional<ArrayRef<bool>> scalableSizes) {
3446  build(builder, result, target,
3447  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3448  interchange, scalableSizes);
3449 }
3450 
3451 void transform::TileUsingForOp::build(
3452  OpBuilder &builder, OperationState &result, Value target,
3453  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3454  std::optional<ArrayRef<bool>> scalableSizes) {
3455  // Loop types are automaticaly splat by the callee, setting up one is
3456  // enough.
3457  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3458  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3459  scalableSizes);
3460 }
3461 
3462 void transform::TileUsingForOp::build(
3463  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3464  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3465  ArrayRef<int64_t> interchange,
3466  std::optional<ArrayRef<bool>> scalableSizes) {
3467  SmallVector<int64_t> staticTileSizes;
3468  SmallVector<Value> dynamicTileSizes;
3469  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3470  // Call the default builder which sets up the proper operands segment sizes
3471  // attributes for multiple variadic operands. In the absence of this,
3472  // horrible bugs ensue.
3473  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3474  unsigned numExpectedLoops =
3475  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3476  SmallVector<Type> resultTypes;
3477  resultTypes.reserve(numExpectedLoops);
3478  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3479  "expected one loop type or as many as loops");
3480  if (loopTypes.size() == 1)
3481  resultTypes.append(numExpectedLoops, loopTypes[0]);
3482  else
3483  llvm::append_range(resultTypes, loopTypes);
3484  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3485  if (scalableSizes.has_value())
3486  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3487  build(builder, result, /*tiled_linalg_op=*/target.getType(),
3488  /*loops=*/resultTypes,
3489  /*target=*/target,
3490  /*dynamic_sizes=*/dynamicTileSizes,
3491  /*static_sizes=*/staticTileSizesAttr,
3492  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3493  /*scalable_sizes=*/expandedScalableSizes);
3494 }
3495 
3496 LogicalResult transform::TileUsingForOp::verify() {
3497  if (getMixedSizes().size() != getScalableSizes().size())
3498  return emitOpError("expected same number of sizes (")
3499  << getMixedSizes().size() << ") and scalable sizes ("
3500  << getScalableSizes().size() << ")";
3501  ArrayRef<int64_t> staticSizes = getStaticSizes();
3502  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3503  if (getLoops().size() != numExpectedLoops)
3504  return emitOpError("expected number of loops to tile (")
3505  << numExpectedLoops << ") to match number of `loops` results ("
3506  << getLoops().size() << ")";
3507  return success();
3508 }
3509 
3511 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3512  TransformResults &transformResults,
3513  TransformState &state) {
3514  ArrayRef<int64_t> tileSizes = getStaticSizes();
3515 
3516  SmallVector<Operation *> targets =
3517  llvm::to_vector(state.getPayloadOps(getTarget()));
3518  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3520  dynamicSizeProducers.reserve(getDynamicSizes().size());
3521  paramSizes.reserve(getDynamicSizes().size());
3522  for (Value transformValue : getDynamicSizes()) {
3523  if (isa<ParamType>(transformValue.getType())) {
3524  dynamicSizeProducers.push_back({});
3525  ArrayRef<Attribute> params = state.getParams(transformValue);
3526  paramSizes.push_back(
3527  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3528  return cast<IntegerAttr>(attr).getValue().getSExtValue();
3529  })));
3530 
3531  if (paramSizes.back().size() != targets.size()) {
3533  emitSilenceableError()
3534  << "expected as many parameter values ("
3535  << dynamicSizeProducers.back().size() << ") as target ops ("
3536  << targets.size() << ")";
3537  diag.attachNote(transformValue.getLoc()) << "for this parameter";
3538  return diag;
3539  }
3540 
3541  continue;
3542  }
3543  paramSizes.push_back({});
3544  dynamicSizeProducers.push_back(
3545  llvm::to_vector(state.getPayloadOps(transformValue)));
3546 
3547  if (dynamicSizeProducers.back().size() != targets.size()) {
3549  emitSilenceableError()
3550  << "expected as many dynamic size-producing operations ("
3551  << dynamicSizeProducers.back().size() << ") as target ops ("
3552  << targets.size() << ")";
3553  diag.attachNote(transformValue.getLoc()) << "for this handle";
3554  return diag;
3555  }
3556 
3557  for (Operation *op : dynamicSizeProducers.back()) {
3558  if (op->getNumResults() == 1 &&
3559  isa<IndexType>(op->getResult(0).getType())) {
3560  continue;
3561  }
3562 
3564  emitSilenceableError() << "expected sizes to be produced by ops "
3565  "with a single index-type result";
3566  diag.attachNote(op->getLoc()) << "size producer op";
3567  diag.attachNote(transformValue.getLoc()) << "for this handle";
3568  return diag;
3569  }
3570  }
3571 
3574  loops.resize(getLoops().size());
3575  auto scalableSizes = getScalableSizes();
3576  for (auto [i, op] : llvm::enumerate(targets)) {
3577  auto tilingInterface = dyn_cast<TilingInterface>(op);
3578  if (!tilingInterface) {
3580  emitSilenceableError()
3581  << "only ops implementing TilingInterface are supported";
3582  diag.attachNote(op->getLoc()) << "target op";
3583  return diag;
3584  }
3585  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3587  emitSilenceableError()
3588  << "too many tiles provided, expected at most "
3589  << tilingInterface.getLoopIteratorTypes().size() << " found "
3590  << tileSizes.size();
3591  diag.attachNote(op->getLoc()) << "target op";
3592  return diag;
3593  }
3594 
3595  scf::SCFTilingOptions tilingOptions;
3596  if (tileSizes.empty()) {
3597  tilingOptions.setTileSizeComputationFunction(
3599  return {};
3600  });
3601  } else {
3602  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3603  Operation *) {
3605  sizes.reserve(tileSizes.size());
3606  unsigned dynamicIdx = 0;
3607 
3608  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3609  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3610  if (scalableSizes[ofrIdx]) {
3611  auto val = arith::ConstantIndexOp::create(
3612  b, getLoc(), cast<IntegerAttr>(attr).getInt());
3613  Value vscale =
3614  vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
3615  sizes.push_back(
3616  arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3617  } else {
3618  sizes.push_back(attr);
3619  }
3620  continue;
3621  }
3622  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3623  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3624  ++dynamicIdx;
3625  assert((dynamicSizes.empty() ^ params.empty()) &&
3626  "expected either dynamic sizes or parameters");
3627  if (!params.empty()) {
3628  sizes.push_back(b.getIndexAttr(params[index]));
3629  } else {
3630  sizes.push_back(dynamicSizes[index]->getResult(0));
3631  }
3632  }
3633  return sizes;
3634  });
3635  }
3636 
3637  tilingOptions.setInterchange(getInterchange());
3638  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3639  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3640  if (failed(maybeTilingResult))
3642 
3643  rewriter.replaceOp(op, maybeTilingResult->replacements);
3644 
3645  tiled.append(maybeTilingResult->tiledOps);
3646  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3647  loops[en2.index()].push_back(en2.value());
3648  }
3649 
3650  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3651  for (const auto &en : llvm::enumerate(loops))
3652  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3653 
3655 }
3656 
3658  ValueRange dynamic = getDynamicSizes();
3659  ArrayRef<int64_t> tileSizes = getStaticSizes();
3660  SmallVector<OpFoldResult> results;
3661  results.reserve(tileSizes.size());
3662  unsigned dynamicPos = 0;
3663  Builder builder(getContext());
3664  for (int64_t size : tileSizes) {
3665  if (size == ShapedType::kDynamic) {
3666  results.push_back(dynamic[dynamicPos++]);
3667  } else {
3668  results.push_back(builder.getIndexAttr(size));
3669  }
3670  }
3671  return results;
3672 }
3673 
3674 void transform::TileUsingForOp::getEffects(
3676  consumesHandle(getTargetMutable(), effects);
3677  onlyReadsHandle(getDynamicSizesMutable(), effects);
3678  producesHandle(getOperation()->getOpResults(), effects);
3679  modifiesPayload(effects);
3680 }
3681 
3682 //===----------------------------------------------------------------------===//
3683 // TileUsingForallOp
3684 //===----------------------------------------------------------------------===//
3685 
3686 void transform::TileUsingForallOp::build(OpBuilder &builder,
3687  OperationState &result, Value target,
3688  ArrayRef<int64_t> staticTileSizes,
3690  ArrayAttr mapping) {
3691  return build(builder, result,
3692  /*target=*/target,
3693  /*mixedTileSizes=*/
3694  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3695  /*_=*/TileSizesSpec(),
3696  /*mapping=*/mapping);
3697 }
3698 
3699 void transform::TileUsingForallOp::build(OpBuilder &builder,
3700  OperationState &result, Value target,
3701  ArrayRef<OpFoldResult> mixedTileSizes,
3703  ArrayAttr mapping) {
3704  SmallVector<int64_t> staticTileSizes;
3705  SmallVector<Value> dynamicTileSizes;
3706  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3707  // Call the default builder which sets up the proper operands segment sizes
3708  // attributes for multiple variadic operands. In the absence of this,
3709  // horrible bugs ensue.
3710  MLIRContext *ctx = builder.getContext();
3711  auto operationType = transform::AnyOpType::get(ctx);
3712  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3713  build(builder, result,
3714  /*resultTypes=*/TypeRange{operationType, operationType},
3715  /*target=*/target,
3716  /*num_threads=*/ValueRange{},
3717  /*tile_sizes=*/dynamicTileSizes,
3718  /*packed_num_threads=*/Value(),
3719  /*packed_tile_sizes=*/Value(),
3720  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3721  /*static_tile_sizes=*/staticTileSizesAttr,
3722  /*mapping=*/mapping);
3723 }
3724 
3725 void transform::TileUsingForallOp::build(OpBuilder &builder,
3726  OperationState &result, Value target,
3727  ArrayRef<int64_t> staticNumThreads,
3729  ArrayAttr mapping) {
3730  return build(builder, result, target,
3731  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3732  NumThreadsSpec(), mapping);
3733 }
3734 
3735 void transform::TileUsingForallOp::build(OpBuilder &builder,
3736  OperationState &result, Value target,
3737  ArrayRef<OpFoldResult> mixedNumThreads,
3739  ArrayAttr mapping) {
3740  SmallVector<int64_t> staticNumThreads;
3741  SmallVector<Value> dynamicNumThreads;
3742  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3743  staticNumThreads);
3744  // Call the default builder which sets up the proper operands segment sizes
3745  // attributes for multiple variadic operands. In the absence of this,
3746  // horrible bugs ensue.
3747  MLIRContext *ctx = builder.getContext();
3748  auto operationType = transform::AnyOpType::get(ctx);
3749  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3750  build(builder, result,
3751  /*resultTypes=*/TypeRange{operationType, operationType},
3752  /*target=*/target,
3753  /*num_threads=*/dynamicNumThreads,
3754  /*tile_sizes=*/ValueRange{},
3755  /*packed_num_threads=*/Value(),
3756  /*packed_tile_sizes=*/Value(),
3757  /*static_num_threads=*/staticNumThreadsAttr,
3758  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3759  /*mapping=*/mapping);
3760 }
3761 
3762 /// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3763 /// normalized upper bound.
3767  ArrayRef<OpFoldResult> steps) {
3768  AffineExpr s0, s1, s2;
3769  bindSymbols(rewriter.getContext(), s0, s1, s2);
3770  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3771  SmallVector<OpFoldResult> normalizedUbs;
3772  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3774  rewriter, loc, normalizedUbExpr, {lb, ub, step});
3775  normalizedUbs.push_back(normalizedUb);
3776  }
3777  return normalizedUbs;
3778 }
3779 
3780 /// When a loop is normalized, the uses of the induction variable within the
3781 /// loop need to replaced with `original_lb + old_iv * original_step`.
3783  Location loc, ValueRange ivs,
3785  ArrayRef<OpFoldResult> steps) {
3786  AffineExpr s0, s1;
3787  AffineExpr d0;
3788  bindSymbols(rewriter.getContext(), s0, s1);
3789  bindDims(rewriter.getContext(), d0);
3790  AffineExpr denormExpr = s0 + d0 * s1;
3791  SmallVector<Value> denormalizedIvs;
3792 
3793  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3795  rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3796  denormalizedIvs.push_back(
3797  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3798  }
3799  return denormalizedIvs;
3800 }
3801 
3802 /// Given a `scf.forall` loop return a loop op with the loop bounds
3803 /// normalized.
3804 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3805 /// At the time of writing, this wasnt done since adding this to `scf`
3806 /// dialect would disallow using of `affine.apply` operations due
3807 /// to cyclic dependencies. To avoid churn in lit tests
3808 /// with the change this was added with, defer that to a follow up.
3809 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3810  scf::ForallOp loop) {
3811  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3812  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3813  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3814 
3815  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3816  return loop;
3817  }
3818 
3819  Location loc = loop.getLoc();
3820  SmallVector<OpFoldResult> normalizedUbs =
3821  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3822  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3823  rewriter.getIndexAttr(0));
3824  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3825  rewriter.getIndexAttr(1));
3826 
3827  auto normalizedForallOp = scf::ForallOp::create(
3828  rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3829  loop.getOutputs(), loop.getMapping(),
3830  [](OpBuilder &, Location, ValueRange) {});
3831 
3832  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3833  OpBuilder::InsertionGuard g(rewriter);
3834  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3835  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3836 
3837  SmallVector<Value> argValues =
3838  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3839  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3840  normalizedForallOp.getRegionIterArgs().end());
3841  Block *origLoopBlock = loop.getBody();
3842  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3843 
3844  rewriter.replaceOp(loop, normalizedForallOp);
3845  return normalizedForallOp;
3846 }
3847 
3849  RewriterBase &rewriter, transform::TransformState &state,
3850  TransformOpInterface transformOp, Operation *target,
3851  ArrayRef<OpFoldResult> mixedNumThreads,
3852  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3853  scf::SCFTilingResult &tilingResult) {
3854  // Transform all targets one by one.
3855  auto tileableOp = dyn_cast<TilingInterface>(target);
3856  if (!tileableOp) {
3858  transformOp.emitSilenceableError()
3859  << "only TilingInterface ops are supported";
3860  diag.attachNote(target->getLoc()) << "target op";
3861  return diag;
3862  }
3863  rewriter.setInsertionPoint(tileableOp);
3866  if (!mixedNumThreads.empty()) {
3867  options.setNumThreads(mixedNumThreads);
3868  } else {
3869  options.setTileSizes(mixedTileSizes);
3870  }
3871  if (mapping) {
3872  options.setMapping(mapping.value().getValue());
3873  }
3874  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3875  scf::tileUsingSCF(rewriter, tileableOp, options);
3876 
3877  if (failed(maybeTilingResult))
3878  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3879 
3880  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3881 
3882  tilingResult = *maybeTilingResult;
3883 
3884  if (mixedNumThreads.empty()) {
3885  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3886  OpBuilder::InsertionGuard g(rewriter);
3887  rewriter.setInsertionPoint(generatedForallOp);
3888  scf::ForallOp normalizedForallOp =
3889  normalizeForallLoopOp(rewriter, generatedForallOp);
3890  tilingResult.loops.front() = normalizedForallOp;
3891  }
3892 
3894 }
3895 
3896 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3897  transform::TransformRewriter &rewriter,
3898  transform::TransformResults &transformResults,
3899  transform::TransformState &state) {
3900  auto transformOp = cast<TransformOpInterface>(getOperation());
3901 
3902  // Result payload ops.
3903  SmallVector<Operation *> tileOps;
3904  SmallVector<Operation *> tiledOps;
3905 
3906  // Unpack handles.
3907  SmallVector<OpFoldResult> mixedNumThreads;
3909  getPackedNumThreads()
3911  state, transformOp, mixedNumThreads, getPackedNumThreads())
3913  state, transformOp, mixedNumThreads, getMixedNumThreads());
3914  if (!status.succeeded())
3915  return status;
3916  SmallVector<OpFoldResult> mixedTileSizes;
3917  status = getPackedTileSizes()
3919  state, transformOp, mixedTileSizes, getPackedTileSizes())
3921  state, transformOp, mixedTileSizes, getMixedTileSizes());
3922  if (!status.succeeded())
3923  return status;
3924 
3925  for (Operation *target : state.getPayloadOps(getTarget())) {
3926  scf::SCFTilingResult tilingResult;
3928  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3929  getMapping(), tilingResult);
3930  if (!diag.succeeded())
3931  return diag;
3932  tileOps.push_back(tilingResult.loops.front());
3933  tiledOps.append(tilingResult.tiledOps);
3934  }
3935 
3936  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3937  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3938 
3940 }
3941 
3942 void transform::TileUsingForallOp::getEffects(
3944  consumesHandle(getTargetMutable(), effects);
3945  onlyReadsHandle(getTileSizesMutable(), effects);
3946  onlyReadsHandle(getNumThreadsMutable(), effects);
3947  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3948  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3949  producesHandle(getOperation()->getOpResults(), effects);
3950  modifiesPayload(effects);
3951 }
3952 
3953 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3954  Builder b(getContext());
3955  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3956 }
3957 
3958 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3959  Builder b(getContext());
3960  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3961 }
3962 
3963 LogicalResult TileUsingForallOp::verify() {
3964  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3965  static_cast<int>(getPackedNumThreads() != Value());
3966  if (numThreadsSpec > 1)
3967  return emitOpError(
3968  "num_threads and packed_num_threads are mutually exclusive");
3969  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3970  static_cast<int>(getPackedTileSizes() != Value());
3971  if (tileSizesSpec > 1)
3972  return emitOpError(
3973  "tile_sizes and packed_tile_sizes are mutually exclusive");
3974  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3975  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3976  "must be specified");
3977  return success();
3978 }
3979 
3980 //===----------------------------------------------------------------------===//
3981 // VectorizeChildrenAndApplyPatternsOp
3982 //===----------------------------------------------------------------------===//
3983 
3984 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3985  OpBuilder &builder, OperationState &result, Value target,
3986  bool foldTypeExtensionsIntoContract, bool vectorizePadding,
3987  bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3988  result.addOperands(target);
3989  if (foldTypeExtensionsIntoContract) {
3990  result.addAttribute(
3991  VectorizeChildrenAndApplyPatternsOp::
3992  getFoldTypeExtensionsIntoContractAttrName(result.name),
3993  builder.getUnitAttr());
3994  }
3995  if (vectorizePadding) {
3996  result.addAttribute(
3997  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3998  result.name),
3999  builder.getUnitAttr());
4000  }
4001  if (vectorizeExtract) {
4002  result.addAttribute(
4003  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4004  result.name),
4005  builder.getUnitAttr());
4006  }
4007  if (flatten1DDepthwiseConv) {
4008  result.addAttribute(
4009  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4010  result.name),
4011  builder.getUnitAttr());
4012  }
4013  result.addTypes(transform::AnyOpType::get(builder.getContext()));
4014 }
4015 
4016 namespace {
4017 /// This is an helper only to call vectorize via a pattern inside of
4018 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
4019 struct VectorizationPattern : public RewritePattern {
4020  explicit VectorizationPattern(MLIRContext *context,
4021  bool vectorizeExtract = false,
4022  bool flattenConv = false)
4023  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
4024  vectorizeNDExtract(vectorizeExtract),
4025  flatten1DDepthwiseConv(flattenConv) {}
4026  LogicalResult matchAndRewrite(Operation *op,
4027  PatternRewriter &rewriter) const override {
4029  return rewriter.notifyMatchFailure(op,
4030  "Unsupported Op, cannot vectorize");
4031  FailureOr<VectorizationResult> vectorResults =
4032  vectorize(rewriter, op, /*inputVectorSizes=*/{},
4033  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
4034  flatten1DDepthwiseConv);
4035  if (failed(vectorResults))
4036  return failure();
4037  rewriter.replaceOp(op, vectorResults->replacements);
4038  return success();
4039  }
4040 
4041 private:
4042  /// Controls whether to vectorize `tensor.extract` when the input tensor is
4043  /// rank >= 2.
4044  bool vectorizeNDExtract = false;
4045  /// Controls whether to "flatten" the channel dimension when vectorising 1D
4046  /// depthwise convolutions. This should lead to bette vectorization for
4047  /// tensors with a low number of channel dimensions.
4048  bool flatten1DDepthwiseConv = false;
4049 };
4050 } // namespace
4051 
4053 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4054  transform::TransformRewriter &rewriter, Operation *target,
4056  transform::TransformState &state) {
4057  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4058  auto diag = this->emitOpError("requires isolated-from-above targets");
4059  diag.attachNote(target->getLoc()) << "non-isolated target";
4061  }
4062 
4063  MLIRContext *ctx = getContext();
4065  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4066  getFlatten_1dDepthwiseConv());
4067 
4068  if (!getDisableTransferPermutationMapLoweringPatterns())
4070 
4071  if (!getDisableMultiReductionToContractPatterns())
4073 
4075 
4078  /*benefit=*/2);
4079  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4080  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4082 
4084 
4085  if (getFoldTypeExtensionsIntoContract())
4087 
4088  if (getVectorizePadding()) {
4090  // This creates an alternative path for lowering tensor.pad - by
4091  // decomposing it into e.g. linalg.fill.
4093  }
4095 
4096  TrackingListener listener(state, *this);
4097  if (failed(
4098  applyPatternsGreedily(target, std::move(patterns),
4099  GreedyRewriteConfig().setListener(&listener))))
4100  return emitDefaultDefiniteFailure(target);
4101 
4102  results.push_back(target);
4104 }
4105 
4106 //===----------------------------------------------------------------------===//
4107 // VectorizeOp
4108 //===----------------------------------------------------------------------===//
4109 
4110 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4111  transform::TransformRewriter &rewriter,
4112  mlir::transform::TransformResults &transformResults,
4114  auto targets = state.getPayloadOps(getTarget());
4115  if (std::empty(targets))
4117  auto transformOp = cast<TransformOpInterface>(getOperation());
4118  SmallVector<int64_t> vectorSizes;
4120  state, transformOp, getMixedVectorSizes(), vectorSizes);
4121  if (!status.succeeded())
4122  return status;
4123 
4124  // TODO: Check that the correct number of vectorSizes was provided.
4125  for (Operation *target : targets) {
4126  if (!linalg::hasVectorizationImpl(target)) {
4127  return mlir::emitSilenceableFailure(target->getLoc())
4128  << "Unsupported Op, cannot vectorize";
4129  }
4130  FailureOr<VectorizationResult> vectorResults =
4131  linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
4132  getVectorizeNdExtract().value_or(false),
4133  /*flatten1DDepthwiseConv=*/false,
4134  getAssumeDynamicDimsMatchVecSizes().value_or(false),
4135  getCreateNamedContraction().value_or(false));
4136  if (failed(vectorResults)) {
4137  return mlir::emitSilenceableFailure(target->getLoc())
4138  << "Attempted to vectorize, but failed";
4139  }
4140  rewriter.replaceOp(target, vectorResults->replacements);
4141  }
4142 
4144 }
4145 
4146 void transform::VectorizeOp::getEffects(
4148  consumesHandle(getTargetMutable(), effects);
4149  onlyReadsHandle(getVectorSizesMutable(), effects);
4150  modifiesPayload(effects);
4151 }
4152 
4153 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4154  OpBuilder b(getContext());
4155  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
4156 }
4157 
4158 LogicalResult transform::VectorizeOp::verify() {
4159  if (getStaticVectorSizes().size() != getScalableSizes().size())
4160  return emitOpError("expected same number of vector sizes (")
4161  << getStaticVectorSizes().size() << ") and scalable sizes ("
4162  << getScalableSizes().size() << ")";
4163  return success();
4164 }
4165 
4166 //===----------------------------------------------------------------------===//
4167 // HoistRedundantVectorTransfersOp
4168 //===----------------------------------------------------------------------===//
4169 
4171 transform::HoistRedundantVectorTransfersOp::applyToOne(
4172  transform::TransformRewriter &rewriter, func::FuncOp target,
4174  transform::TransformState &state) {
4175  // WARNING: This hoisting does not model parallelism and is generally
4176  // incorrect when used on distributed loops with memref semantics!
4177  // TODO: obsolete and should be retired.
4178  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
4179  results.push_back(target);
4181 }
4182 
4183 //===----------------------------------------------------------------------===//
4184 // HoistRedundantVectorBroadcastsOp
4185 //===----------------------------------------------------------------------===//
4186 
4188 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4189  transform::TransformRewriter &rewriter, mlir::Operation *target,
4191  transform::TransformState &state) {
4192  rewriter.setInsertionPoint(target);
4193  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
4194  results.push_back(target);
4196 }
4197 
4198 //===----------------------------------------------------------------------===//
4199 // ConvertConv2DToImg2ColOp.
4200 //===----------------------------------------------------------------------===//
4201 
4202 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4203  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4205  transform::TransformState &state) {
4206  rewriter.setInsertionPoint(target);
4207  auto maybeTransformed =
4209  target)
4210  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4211  return rewriteInIm2Col(rewriter, op);
4212  })
4213  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4214  return rewriteInIm2Col(rewriter, op);
4215  })
4216  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4217  return rewriteInIm2Col(rewriter, op);
4218  })
4219  .Case([&](linalg::Conv2DNchwFchwOp op) {
4220  return rewriteInIm2Col(rewriter, op);
4221  })
4222  .Default([&](Operation *op) {
4223  return rewriter.notifyMatchFailure(op, "not supported");
4224  });
4225  if (failed(maybeTransformed))
4226  return emitDefaultSilenceableFailure(target);
4227  // Handle to the operation producing the img2col tensor.
4228  results.push_back(maybeTransformed->first);
4229  // Handle to the operation that replaces the original convolution.
4230  results.push_back(maybeTransformed->second);
4232 }
4233 
4234 //===----------------------------------------------------------------------===//
4235 // FlattenElementwiseLinalgOp.
4236 //===----------------------------------------------------------------------===//
4237 
4238 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4239  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4241  transform::TransformState &state) {
4242  rewriter.setInsertionPoint(target);
4243  if (!isElementwise(target))
4244  return mlir::emitSilenceableFailure(target->getLoc())
4245  << "only elementwise flattening is supported";
4246 
4247  // If rank <= 1, do nothing
4248  if (target.getNumLoops() <= 1) {
4249  results.push_back(target);
4251  }
4252 
4253  // Attempt to flatten all dims to one.
4254  ReassociationIndices reassociation(target.getNumLoops());
4255  std::iota(reassociation.begin(), reassociation.end(), 0);
4256  auto maybeFlattened =
4257  collapseOpIterationDims(target, reassociation, rewriter);
4258  if (failed(maybeFlattened))
4259  return mlir::emitSilenceableFailure(target->getLoc())
4260  << "attempted to flatten, but failed";
4261  results.push_back(maybeFlattened->collapsedOp);
4262  rewriter.replaceOp(target, maybeFlattened->results);
4264 }
4265 
4266 //===----------------------------------------------------------------------===//
4267 // TransposeConv2DOp
4268 //===----------------------------------------------------------------------===//
4269 
4270 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4271  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4273  transform::TransformState &state) {
4274  rewriter.setInsertionPoint(target);
4275  auto maybeTransformed =
4277  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4278  return transposeConv2D(rewriter, op);
4279  })
4280  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4281  return transposeConv2D(rewriter, op);
4282  })
4283  .Default([&](Operation *op) {
4284  return rewriter.notifyMatchFailure(op, "not supported");
4285  });
4286  if (failed(maybeTransformed))
4287  return emitDefaultSilenceableFailure(target);
4288  // Handle to the new Conv2D operation with transposed filters
4289  results.push_back(*maybeTransformed);
4291 }
4292 
4293 //===----------------------------------------------------------------------===//
4294 // TransposeMatmulOp
4295 //===----------------------------------------------------------------------===//
4296 
4297 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4298  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4300  transform::TransformState &state) {
4301  rewriter.setInsertionPoint(target);
4302  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4303  auto maybeTransformed =
4305  .Case([&](linalg::MatmulOp op) {
4306  return transposeMatmul(rewriter, op, transposeLHS);
4307  })
4308  .Case([&](linalg::BatchMatmulOp op) {
4309  return transposeBatchMatmul(rewriter, op, transposeLHS);
4310  })
4311  .Default([&](Operation *op) { return failure(); });
4312  if (failed(maybeTransformed))
4313  return emitSilenceableFailure(target->getLoc()) << "not supported";
4314  // Handle to the new Matmul operation with transposed filters
4315  results.push_back(*maybeTransformed);
4317 }
4318 
4319 //===----------------------------------------------------------------------===//
4320 // InsertSliceToCopyOp
4321 //===----------------------------------------------------------------------===//
4322 template <typename OpTy>
4325  transform::TransformState &state) {
4326  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4327  tensor::ParallelInsertSliceOp>() &&
4328  "wrong op type");
4329 
4330  if (auto copySource =
4331  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4332  results.push_back(copySource);
4334  }
4335 
4336  // If we are inside a `ParallelCombiningOp` region, temporarily set the
4337  // insertion point outside: only ops implementing ParallelCombiningOpInterface
4338  // are allowed in there.
4339  if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4340  rewriter.setInsertionPoint(target->getParentOp());
4341 
4342  Value extracted = tensor::ExtractSliceOp::create(
4343  rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4344  target.getMixedSizes(), target.getMixedStrides());
4345  Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4346  target.getSource(), extracted)
4347  .getResult(0);
4348  // Reset the insertion point.
4349  rewriter.setInsertionPoint(target);
4350  rewriter.replaceOpWithNewOp<OpTy>(
4351  target, copied, target.getDest(), target.getMixedOffsets(),
4352  target.getMixedSizes(), target.getMixedStrides());
4353 
4354  results.push_back(copied.getDefiningOp());
4356 }
4357 
4358 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4359  transform::TransformRewriter &rewriter, Operation *targetOp,
4361  transform::TransformState &state) {
4362 
4363  rewriter.setInsertionPoint(targetOp);
4364  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4365  return doit(rewriter, target, results, state);
4366  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4367  return doit(rewriter, target, results, state);
4368 
4370  emitSilenceableError()
4371  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4372  diag.attachNote(targetOp->getLoc()) << "target op";
4373  return diag;
4374 }
4375 
4376 //===----------------------------------------------------------------------===//
4377 // MapCopyToThreadsOp
4378 //===----------------------------------------------------------------------===//
4379 
4380 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4381  transform::TransformRewriter &rewriter, Operation *target,
4383  transform::TransformState &state) {
4384  // Check if the op is supported.
4385  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4387  emitSilenceableError()
4388  << "only linalg.copy and tensor.pad target ops are supported";
4389  diag.attachNote(target->getLoc()) << "target op";
4390  return diag;
4391  }
4392  assert(target->getNumResults() == 1 && "expected single result");
4393  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4394  if (!resultShapedType.hasStaticShape()) {
4396  emitSilenceableError()
4397  << "only statically sized ops of rank <= 3 are supported";
4398  diag.attachNote(target->getLoc()) << "target op";
4399  return diag;
4400  }
4401 
4402  // Conservatively set the minimum viable desired bitwidth alignment.
4403  int64_t desiredBitAlignment = getDesiredBitAlignment();
4404  int64_t eltBitwidth =
4405  resultShapedType.getElementType().getIntOrFloatBitWidth();
4406  if (desiredBitAlignment % eltBitwidth != 0) {
4407  desiredBitAlignment = eltBitwidth;
4408  }
4409 
4410  gpu::CopyMappingInfo mapping(
4411  /*ctx=*/getContext(),
4412  /*totalNumThreads=*/getTotalNumThreads(),
4413  /*alignment=*/desiredBitAlignment,
4414  /*sizes=*/resultShapedType.getShape(),
4415  /*favorPredication=*/false,
4416  /*elementalBitwidth=*/
4417  resultShapedType.getElementType().getIntOrFloatBitWidth());
4418  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4420  emitSilenceableError()
4421  << "too few threads to map copy op to threads on the most minor "
4422  "dimension, given alignment and vector size constraints, try "
4423  "smaller tile size of mapping to more threads";
4424  diag.attachNote(target->getLoc()) << "target op";
4425  return diag;
4426  }
4427 
4428  // OpBuilder only used to compute attributes.
4429  OpBuilder b(getContext());
4430  scf::SCFTilingResult tilingResult;
4432  /*rewriter=*/rewriter,
4433  /*state=*/state,
4434  /*transformOp=*/*this,
4435  /*target=*/target,
4436  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4437  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4438  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4439  /*tilingResult=*/tilingResult);
4440  if (!diag.succeeded())
4441  return diag;
4442 
4443  results.push_back(tilingResult.loops.front());
4444  for (auto op : tilingResult.tiledOps)
4445  results.push_back(op);
4447 }
4448 
4449 //===----------------------------------------------------------------------===//
4450 // WinogradConv2DOp
4451 //===----------------------------------------------------------------------===//
4452 
4453 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4454  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4456  transform::TransformState &state) {
4457  rewriter.setInsertionPoint(target);
4458  FailureOr<Operation *> maybeTransformed = failure();
4459  bool supported = TypeSwitch<Operation *, bool>(target)
4460  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4461  maybeTransformed =
4462  winogradConv2D(rewriter, op, getFmr());
4463  return true;
4464  })
4465  .Default([&](Operation *op) { return false; });
4466 
4467  if (!supported) {
4468  return emitSilenceableError()
4469  << "this operation is not supported to convert to Winograd Conv2D";
4470  }
4471 
4472  if (failed(maybeTransformed)) {
4473  return emitSilenceableError() << "apply Winograd Conv2D failed";
4474  }
4475 
4476  results.push_back(*maybeTransformed);
4478 }
4479 
4480 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4481  transform::TransformRewriter &rewriter, Operation *target,
4483  transform::TransformState &state) {
4484  rewriter.setInsertionPoint(target);
4485  FailureOr<Operation *> maybeTransformed = failure();
4486  bool supported =
4488  .Case([&](linalg::WinogradFilterTransformOp op) {
4489  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4490  return true;
4491  })
4492  .Case([&](linalg::WinogradInputTransformOp op) {
4493  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4494  return true;
4495  })
4496  .Case([&](linalg::WinogradOutputTransformOp op) {
4497  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4498  return true;
4499  })
4500  .Default([&](Operation *op) { return false; });
4501 
4502  if (!supported) {
4504  emitSilenceableError()
4505  << "this operation is not supported to decompose into other operations";
4506  diag.attachNote(target->getLoc()) << "target op";
4507  return diag;
4508  }
4509 
4510  if (failed(maybeTransformed)) {
4512  emitSilenceableError() << "decompose Winograd operations failed";
4513  diag.attachNote(target->getLoc()) << "target op";
4514  return diag;
4515  }
4516 
4517  results.push_back(*maybeTransformed);
4519 }
4520 
4521 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4522 
4523 #define GET_OP_CLASSES
4524 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static bool mayBeRead(OpOperand &operand)
Return true if the operand may be read from by its owner.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static bool sameOrEquivalentIterArg(Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
UnitAttr getUnitAttr()
Definition: Builders.cpp:98
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:167
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:368
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:281
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:306
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
This class represents a saved insertion point.
Definition: Builders.h:327
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:562
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:316
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:710
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:208
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1276
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1374
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition: Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:471
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:657
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:245
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:346
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:857
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:512
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:496
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:487
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:400
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
Definition: Promotion.cpp:504
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:220
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:45
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:236
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:106
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:89
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:748
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:464
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:422
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:491
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
Definition: Promotion.cpp:480
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:198
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:262
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:217
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:651
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:61
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:431
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:432
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:522
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1626
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1558
Match and rewrite for the pattern:
Definition: Transforms.h:1755
Match and rewrite for the pattern:
Definition: Transforms.h:1783
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:421
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:434
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:447
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:467
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:410
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:441
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:457
LinalgPromotionOptions & setUseOriginalSubviewSize(bool originalSize)
Definition: Transforms.h:428
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:399
Split Reduction options.
Definition: Transforms.h:476
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
SCFTilingOptions & setLoopType(LoopType type)
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.