1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
41 #include "mlir/Support/LLVM.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/DebugLog.h"
47 #include "llvm/Support/LogicalResult.h"
48 #include <type_traits>
49 
50 using namespace mlir;
51 using namespace mlir::linalg;
52 using namespace mlir::transform;
53 
54 #define DEBUG_TYPE "linalg-transforms"
55 
56 /// Attempts to apply the pattern specified as template argument to the given
57 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
58 /// function that returns the "main" result or failure. Returns failure if the
59 /// pattern failed to apply. Extra arguments are forwarded to the pattern
60 /// constructor.
61 template <typename PatternTy, typename... Args>
62 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
63  // Check if the given operation has the type expected by the pattern.
64  using OpTy = typename llvm::function_traits<
65  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
66  auto op = dyn_cast<OpTy>(operation);
67  if (!op)
68  return failure();
69 
70  // Apply the pattern directly to the op.
71  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
72  // We want to discourage direct use of PatternRewriter in APIs, but in this
73  // very specific case, an IRRewriter is not enough.
74  PatternRewriter rewriter(operation->getContext());
75  rewriter.setInsertionPoint(operation);
76  auto result = pattern.returningMatchAndRewrite(op, rewriter);
77  if (failed(result))
78  return failure();
79  return cast<LinalgOp>(result->getOperation());
80 }
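// Illustrative sketch (not an additional code path in this file): `tryApply`
// is meant to be invoked with the pattern type as a template argument, e.g.
// via a hypothetical helper such as
//
//   static FailureOr<LinalgOp> exampleDownscale(Operation *target) {
//     return tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                           Conv1DNwcWcfOp>>(
//         target);
//   }
//
// which is exactly the call shape produced by the DOWNSCALE macros used by
// DecomposeOp further below.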
81 
82 /// Assuming that each `ofr` in `ofrs` is an index attr, a param of index
83 /// type, or a transform dialect handle mapped to exactly one op with one
84 /// index result, append that value to `result`.
85 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
86  transform::TransformState &state, TransformOpInterface transformOp,
87  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
88  for (OpFoldResult ofr : ofrs) {
89  if (auto attr = dyn_cast<Attribute>(ofr)) {
90  if (!isa<IntegerAttr>(attr))
91  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
92  result.push_back(ofr);
93  continue;
94  }
95 
96  Value transformValue = cast<Value>(ofr);
97  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
98  ArrayRef<Attribute> params = state.getParams(transformValue);
99  if (params.size() != 1)
100  return transformOp.emitDefiniteFailure()
101  << "requires exactly one parameter associated";
102  result.push_back(params[0]);
103  continue;
104  }
105 
106  auto payloadOps = state.getPayloadOps(transformValue);
107  if (!llvm::hasSingleElement(payloadOps)) {
108  DiagnosedSilenceableFailure diag =
109  transformOp.emitSilenceableError()
110  << "handle must be mapped to exactly one payload op";
111  diag.attachNote(transformValue.getLoc())
112  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
113  return diag;
114  }
115 
116  Operation *op = *payloadOps.begin();
117  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
118  DiagnosedSilenceableFailure diag =
119  transformOp.emitSilenceableError()
120  << "payload op must have exactly 1 index result";
121  diag.attachNote(op->getLoc())
122  << "has " << op->getNumResults() << " results";
123  return diag;
124  }
125  result.push_back(op->getResult(0));
126  }
127 
128  return DiagnosedSilenceableFailure::success();
129 }
130 
131 // Given a packed handle that is either a param of index attrs or a handle
132 // mapped to payload ops, append to `result` a list of OpFoldResults: the
133 // index attrs themselves, or, for each mapped payload op, its first (and
134 // only) OpResult.
135 // (Each attribute associated with the param must be an integer attr, and
136 // each mapped payload op must have exactly one index result.)
137 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
138  transform::TransformState &state, TransformOpInterface transformOp,
139  SmallVector<OpFoldResult> &result, Value packedHandle) {
140  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
141  ArrayRef<Attribute> params = state.getParams(packedHandle);
142  for (auto param : params) {
143  if (!isa<IntegerAttr>(param))
144  return transformOp.emitDefiniteFailure()
145  << "expected the parameter to be associated with an integer "
146  "attribute";
147  result.push_back(param);
148  }
149  return DiagnosedSilenceableFailure::success();
150  }
151 
152  for (Operation *op : state.getPayloadOps(packedHandle)) {
153  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
154  DiagnosedSilenceableFailure diag =
155  transformOp.emitSilenceableError()
156  << "payload op must have exactly 1 index result";
157  diag.attachNote(op->getLoc())
158  << "has " << op->getNumResults() << " results";
159  return diag;
160  }
161  result.push_back(op->getResult(0));
162  }
163 
164  return DiagnosedSilenceableFailure::success();
165 }
166 
167 /// When possible, converts each `OpFoldResult` in `mixedResults` to
168 /// an integer if the value can be statically inferred. If a result
169 /// is a `Value` then it must be either a `ParamType` or a handle
170 /// to a constant-like op.
171 static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
172  TransformState &state, TransformOpInterface &transformOp,
173  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
174  for (OpFoldResult paramOrHandle : mixedResults) {
175  if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
176  reified.push_back(cast<IntegerAttr>(attr).getInt());
177  continue;
178  } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
179  ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
180  if (params.size() != 1)
181  return transformOp.emitSilenceableError() << "expected a single param";
182  reified.push_back(
183  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
184  continue;
185  }
186 
187  Value handle = cast<Value>(paramOrHandle);
188  if (!isa<TransformHandleTypeInterface>(handle.getType()))
189  return transformOp.emitSilenceableError() << "unexpected value handle";
190  auto payload = state.getPayloadOps(handle);
191  if (!llvm::hasSingleElement(payload))
192  return transformOp.emitSilenceableError()
193  << "requires param or handle that is mapped to 1 payload op";
194 
195  Operation *paramOrHandlePayloadOp = *payload.begin();
196  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
197  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
198  return transformOp.emitSilenceableError()
199  << "requires param or handle to be result of op with 1 index "
200  "result";
201  }
202 
203  IntegerAttr attr;
204  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
205  return transformOp.emitSilenceableError()
206  << "requires param or handle to be the result of a constant like "
207  "op";
208 
209  reified.push_back(attr.getInt());
210  }
211  return DiagnosedSilenceableFailure::success();
212 }
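// The helpers above accept, for each size-like argument, one of three forms:
// an index IntegerAttr (a static value), a transform param carrying a single
// integer attribute, or a handle mapped to exactly one payload op with a
// single `index` result (for reification that op must also be constant-like).
// A rough transform-IR illustration (syntax abbreviated, for intuition only):
//
//   %sz = transform.param.constant 4 : i64 -> !transform.param<i64>
//   // ...or a handle mapped to a single `arith.constant 4 : index` payload op.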
213 
214 //===----------------------------------------------------------------------===//
215 // Apply...PatternsOp
216 //===----------------------------------------------------------------------===//
217 
218 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
219  RewritePatternSet &patterns) {
220  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
221 }
222 
223 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
224  RewritePatternSet &patterns) {
225  linalg::populateDecomposePackUnpackPatterns(patterns);
226 }
227 
228 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
229  RewritePatternSet &patterns) {
230  linalg::populateDecomposePadPatterns(patterns);
231 }
232 
233 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
234  RewritePatternSet &patterns) {
235  linalg::ControlDropUnitDims options;
236  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
237 }
238 
239 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
240  RewritePatternSet &patterns) {
241  linalg::ControlDropUnitDims options;
242  options.rankReductionStrategy =
243  linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
244  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
245 }
246 
247 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
248  RewritePatternSet &patterns) {
249  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
250 }
251 
252 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
253  RewritePatternSet &patterns) {
254  linalg::populateFoldAddIntoDestPatterns(patterns);
255 }
256 
257 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
258  RewritePatternSet &patterns) {
259  linalg::populatePadOpVectorizationPatterns(patterns);
260 }
261 
262 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
263  RewritePatternSet &patterns) {
264  linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
265 }
266 
267 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
268  RewritePatternSet &patterns) {
269  linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // BufferizeToAllocationOp
274 //===----------------------------------------------------------------------===//
275 
276 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
277  OperationState &result,
278  Value target,
279  Attribute memorySpace) {
280  SmallVector<Type> resultTypes;
281  resultTypes.push_back(b.getType<transform::AnyValueType>());
282  resultTypes.push_back(b.getType<transform::AnyOpType>());
283  return build(b, result,
284  /*resultTypes=*/resultTypes,
285  /*target=*/target,
286  /*memory_space=*/memorySpace);
287 }
288 
289 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
290  OperationState &result,
291  Value target,
292  int64_t memorySpace) {
293  SmallVector<Type> resultTypes;
294  resultTypes.push_back(b.getType<transform::AnyValueType>());
295  resultTypes.push_back(b.getType<transform::AnyOpType>());
296  return build(b, result,
297  /*resultTypes=*/resultTypes,
298  /*target=*/target,
299  /*memory_space=*/b.getI64IntegerAttr(memorySpace));
300 }
301 
302 namespace {
303 class NewOpsListener : public RewriterBase::ForwardingListener {
304 public:
305  using RewriterBase::ForwardingListener::ForwardingListener;
306 
307  SmallVector<Operation *> getNewOps() const {
308  return SmallVector<Operation *>(newOps.begin(), newOps.end());
309  }
310 
311 private:
312  void notifyOperationInserted(Operation *op,
313  OpBuilder::InsertPoint previous) override {
314  ForwardingListener::notifyOperationInserted(op, previous);
315  // We only care about newly created ops.
316  if (previous.isSet())
317  return;
318  auto inserted = newOps.insert(op);
319  (void)inserted;
320  assert(inserted.second && "expected newly created op");
321  }
322 
323  void notifyOperationErased(Operation *op) override {
324  ForwardingListener::notifyOperationErased(op);
325  op->walk([&](Operation *op) { newOps.erase(op); });
326  }
327 
328  DenseSet<Operation *> newOps;
329 };
330 } // namespace
331 
332 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
333  transform::TransformRewriter &rewriter,
334  transform::TransformResults &results, transform::TransformState &state) {
335  // Attach listener to keep track of newly created ops.
336  OpBuilder::Listener *previousListener = rewriter.getListener();
337  auto resetListener =
338  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
339  NewOpsListener newOpsListener(previousListener);
340  rewriter.setListener(&newOpsListener);
341 
342  linalg::BufferizeToAllocationOptions options;
343  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
344  options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
345  MaterializeInDestination;
346  } else if (getMemcpyOp() == "memref.copy") {
347  options.memcpyOp =
348  linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
349  } else if (getMemcpyOp() == "linalg.copy") {
350  options.memcpyOp =
351  linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
352  } else {
353  llvm_unreachable("invalid memcpy op");
354  }
355  if (getAllocOp() == "memref.alloc") {
356  options.allocOp =
357  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
358  } else if (getAllocOp() == "memref.alloca") {
359  options.allocOp =
360  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
361  } else {
362  llvm_unreachable("invalid alloc op");
363  }
364  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
365  options.emitDealloc = getEmitDealloc();
366 
367  // Bufferize ops.
368  Attribute memorySpace =
369  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
370  SmallVector<Value> allocatedBuffers;
371  for (Operation *op : state.getPayloadOps(getTarget())) {
372  Value buffer =
373  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
374  if (!buffer) {
375  DiagnosedSilenceableFailure diag = emitSilenceableError()
376  << "failed to bufferize operation";
377  diag.attachNote(op->getLoc()) << "target payload op";
378  return diag;
379  }
380  allocatedBuffers.push_back(buffer);
381  }
382 
383  // Set results.
384  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
385  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
386  return DiagnosedSilenceableFailure::success();
387 }
388 
389 void transform::BufferizeToAllocationOp::getEffects(
390  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
391  if (getBufferizeDestinationOnly()) {
392  // The destination is replaced with a newly allocated buffer, but the op
393  // itself remains in place.
394  onlyReadsHandle(getTargetMutable(), effects);
395  } else {
396  consumesHandle(getTargetMutable(), effects);
397  }
398  producesHandle(getOperation()->getOpResults(), effects);
399  modifiesPayload(effects);
400 }
401 
402 LogicalResult transform::BufferizeToAllocationOp::verify() {
403  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
404  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
405  return emitOpError() << "unsupported memcpy op";
406  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
407  return emitOpError() << "unsupported alloc op";
408  return success();
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // DecomposeOp
413 //===----------------------------------------------------------------------===//
414 
415 DiagnosedSilenceableFailure
416 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
417  LinalgOp target,
418  transform::ApplyToEachResultList &results,
419  transform::TransformState &state) {
420 #define DOWNSCALE(trans) \
421  { \
422  FailureOr<LinalgOp> res = tryApply<trans>(target); \
423  if (succeeded(res)) { \
424  results.push_back(*res); \
425  return DiagnosedSilenceableFailure::success(); \
426  } \
427  }
428 
429 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
430 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
431 
432  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
433  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
434  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
435  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
436  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
437  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
438  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
439  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
440  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
441  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
442  DOWNSCALE(DownscaleConv2DOp)
443 #undef DOWNSCALE_NORMAL
444 #undef DOWNSCALE_CALL
445 #undef DOWNSCALE
446  return emitDefaultSilenceableFailure(target);
447 }
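// For reference, each DOWNSCALE_NORMAL(a, b) invocation above expands, per the
// macros, to roughly the following block, so the op tries every 2D -> 1D
// rewrite in turn and succeeds on the first one that applies:
//
//   {
//     FailureOr<LinalgOp> res =
//         tryApply<DownscaleSizeOneWindowed2DConvolution<a, b>>(target);
//     if (succeeded(res)) {
//       results.push_back(*res);
//       return DiagnosedSilenceableFailure::success();
//     }
//   }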
448 
449 //===----------------------------------------------------------------------===//
450 // DecomposeInterfaceOp
451 //===----------------------------------------------------------------------===//
452 
453 // Decompose the target operation if it implements the AggregatedOpInterface.
454 // Push the decomposed operations (the ones that replace the values produced
455 // by \p target) into `results`.
456 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
457  transform::TransformRewriter &rewriter, Operation *target,
458  transform::ApplyToEachResultList &results,
459  transform::TransformState &state) {
460  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
461  if (!decomposableOp) {
462  failed(rewriter.notifyMatchFailure(target,
463  "payload is not a decomposable op"));
464  return emitDefaultSilenceableFailure(target);
465  }
466 
467  FailureOr<SmallVector<Value>> maybeNewResults =
468  decomposableOp.decomposeOperation(rewriter);
469  if (failed(maybeNewResults))
470  return emitDefaultSilenceableFailure(target);
471 
472  rewriter.replaceOp(decomposableOp, *maybeNewResults);
473  for (Value val : *maybeNewResults) {
474  Operation *definition = val.getDefiningOp();
475  if (definition)
476  results.push_back(definition);
477  }
478  return DiagnosedSilenceableFailure::success();
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // EliminateLinalgOpAnchoredEmptyTensorsOp
483 //===----------------------------------------------------------------------===//
484 
485 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
486  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
487  onlyReadsHandle(getTargetMutable(), effects);
488  modifiesPayload(effects);
489 }
490 
491 DiagnosedSilenceableFailure
492 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
493  transform::TransformRewriter &rewriter, TransformResults &transformResults,
494  TransformState &state) {
495  bufferization::OneShotBufferizationOptions options;
496  options.allowReturnAllocsFromLoops = true;
497 
498  for (Operation *target : state.getPayloadOps(getTarget())) {
499  bufferization::OneShotAnalysisState state(target, options);
500  if (failed(analyzeOp(target, state)))
501  return mlir::emitSilenceableFailure(target->getLoc())
502  << "failed to analyze op";
503  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
504  rewriter, target, state)))
505  return mlir::emitSilenceableFailure(target->getLoc())
506  << "failed to eliminate LinalgOp anchored tensor.empty ops";
507  }
508  return DiagnosedSilenceableFailure::success();
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // FuseOp
513 //===----------------------------------------------------------------------===//
514 
515 /// Apply a tiling transformation to all payload ops and store both the
516 /// tiled operation as well as the created tile loops.
517 template <typename Range>
518 static LogicalResult applyTilingToAll(
519  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
520  unsigned numLoops, transform::TransformResults &transformResults,
521  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
522  applyFn) {
523  SmallVector<Operation *> tiledLinalgOps;
524  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
525 
526  for (Operation *target : payloadOps) {
527  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
528  if (!tilingInterfaceOp)
529  return transformOp->emitError("only TilingInterface ops are supported");
530 
531  rewriter.setInsertionPoint(target);
532  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
533  applyFn(tilingInterfaceOp);
534  if (failed(tiledResults))
535  return failure();
536 
537  // Perform the replacement of tiled and fused values.
538  SmallVector<Operation *> opsToReplace{target};
539  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
540  for (Operation *toReplace : opsToReplace) {
541  for (OpResult res : toReplace->getResults())
542  if (auto replacement = tiledResults->replacements.lookup(res))
543  rewriter.replaceAllUsesWith(res, replacement);
544  if (toReplace->use_empty()) {
545  rewriter.eraseOp(toReplace);
546  }
547  }
548 
549  // Report back the relevant handles to the transform op.
550  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
551  assert(tiledResults->loops.size() == numLoops &&
552  "Mismatched number of loops, tile and fuse transform should have "
553  "failed");
554  for (unsigned int i = 0; i < numLoops; ++i)
555  loopOps[i].push_back(tiledResults->loops[i]);
556  }
557 
558  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
559  for (unsigned int i = 0; i < numLoops; ++i)
560  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
561 
562  return success();
563 }
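// Result layout produced by applyTilingToAll: transform result #0 receives the
// tiled (and fused) payload ops, and results #1..#numLoops receive the
// generated loops, one handle per loop nesting level (result #1 holding the
// outermost loops), with one entry per payload op in each handle.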
564 
565 DiagnosedSilenceableFailure
566 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
567  mlir::transform::TransformResults &transformResults,
568  mlir::transform::TransformState &state) {
569  SmallVector<int64_t> tileSizes =
570  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
571  SmallVector<int64_t> tileInterchange =
572  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
573 
574  scf::SCFTilingOptions tilingOptions;
575  tilingOptions.interchangeVector = tileInterchange;
576  SmallVector<OpFoldResult> tileSizesOfr =
577  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
578  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
579  scf::SCFTileAndFuseOptions tileAndFuseOptions;
580  tileAndFuseOptions.tilingOptions = tilingOptions;
581 
582  if (getApplyCleanup()) {
583  MLIRContext *context = rewriter.getContext();
584  RewritePatternSet patterns(context);
585  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
588  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
589  }
590 
591  LogicalResult result = applyTilingToAll(
592  rewriter, getOperation(), state.getPayloadOps(getTarget()),
593  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
594  [&](TilingInterface tilingInterfaceOp)
595  -> FailureOr<scf::SCFTileAndFuseResult> {
596  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
597  tileAndFuseOptions);
598  });
599  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
600  : DiagnosedSilenceableFailure::success();
601 }
602 
603 LogicalResult transform::FuseOp::verify() {
604  SmallVector<int64_t> permutation =
605  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
606  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
607  if (!std::is_permutation(sequence.begin(), sequence.end(),
608  permutation.begin(), permutation.end())) {
609  return emitOpError() << "expects interchange to be a permutation, found "
610  << getTileInterchange();
611  }
612 
613  SmallVector<int64_t> sizes =
614  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
615  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
616  if (numExpectedLoops != getNumResults() - 1)
617  return emitOpError() << "expects " << numExpectedLoops << " loop results";
618 
619  return success();
620 }
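// A quick worked example of the check above: entries equal to 0 in tile_sizes
// do not produce loops, so the op must declare one loop result per non-zero
// tile size (in addition to the fused-op result). The helper below is a
// hypothetical illustration of that arithmetic only; it is not used elsewhere.
//
//   static size_t exampleNumExpectedLoops(ArrayRef<int64_t> tileSizes) {
//     // E.g. {4, 0, 8} -> 2 expected loop results.
//     return tileSizes.size() - llvm::count(tileSizes, 0);
//   }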
621 
622 //===----------------------------------------------------------------------===//
623 // FuseIntoContainingOp
624 //===----------------------------------------------------------------------===//
625 
626 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
627  OperationState &result,
628  Value producerOp,
629  Value containingOp) {
630  result.addOperands({producerOp, containingOp});
631  auto resultType = transform::AnyOpType::get(builder.getContext());
632  result.addTypes({resultType, resultType});
633 }
634 
635 /// Add new operands to the forall op for users of the producerOp
636 /// that are dominated by the containing scf.forall op.
637 static Operation *replaceForAllWithNewSignature(
638  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
639  Operation *containingOp, TilingResult &tileAndFuseResult,
640  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
641  SmallVector<OpFoldResult> &sizes) {
642 
643  // Count number of users not including the containing op
644  SetVector<Operation *> dominatedUsers;
645  DominanceInfo domInfo(containingOp);
646  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
647  if (!containingOp->isAncestor(user) &&
648  (domInfo.dominates(containingOp, user))) {
649  dominatedUsers.insert(user);
650  }
651  }
652  if (dominatedUsers.empty())
653  return nullptr;
654 
655  // Create new scf.forall op
656  auto forallOp = cast<scf::ForallOp>(containingOp);
657  OpBuilder::InsertionGuard g(rewriter);
658  rewriter.setInsertionPoint(forallOp);
659 
660  // Get new output
661  Location loc = forallOp.getLoc();
662  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
663  if (!genericOp)
664  return nullptr;
665  SmallVector<Value> outputs = genericOp.getOutputs();
666  SmallVector<Value> newOuts(forallOp.getOutputs());
667  newOuts.push_back(outputs[resultNumber]);
668 
669  // Create new scf.forall op
670  auto newforallOp = scf::ForallOp::create(
671  rewriter, loc, forallOp.getMixedLowerBound(),
672  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
673  forallOp.getMapping());
674  rewriter.eraseBlock(newforallOp.getBody());
675  newforallOp.getRegion().takeBody(forallOp.getRegion());
676 
677  // Add additional block argument for new value being returned
678  // and replaces all uses of the new output with corresponding bbArg
679  // inside the scf.forall to enable fusion into this new scf.forall.
680  newforallOp.getBody()->addArgument(newOuts.back().getType(),
681  newOuts.back().getLoc());
682  auto bbArgs = newforallOp.getBody()->getArguments();
683  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
684  [&](OpOperand &use) {
685  Operation *op = use.getOwner();
686  return newforallOp->isProperAncestor(op);
687  });
688 
689  // Fix terminator
690  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
691  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
692  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
693  Operation *firstYieldOp = yieldingOps.front();
694  rewriter.setInsertionPoint(firstYieldOp);
695  Value src = tileAndFuseResult.tiledValues[0];
696  Value dst = newforallOp.getRegionIterArgs().back();
697  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
698  tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
699  dst, offsets, sizes, strides);
700 
701  for (auto result : llvm::enumerate(forallOp.getResults())) {
702  rewriter.replaceAllUsesWith(result.value(),
703  newforallOp->getResult(result.index()));
704  }
705  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
706  newforallOp->getResults().back(),
707  [&](OpOperand &use) {
708  Operation *user = use.getOwner();
709  return dominatedUsers.contains(user);
710  });
711  return newforallOp;
712 }
713 
714 /// Given two operands coming from a loop iter arg, 'src' and 'dst', return
715 /// true if the operand 'src' is equal to 'dst' or equal to an iter arg present
716 /// in an outer loop. To determine the second condition, this function iterates
717 /// using a worklist over the enclosing loops, trying to find 'src' in any of
718 /// the parent loop's iter args.
719 static bool sameOrEquivalentIterArg(Value src, Value dst) {
720  // Stack-like vector containing possible iterArg candidates. The first one
721  // is dst, and we will traverse the IR from there.
722  SmallVector<Value> destWorklist;
723  destWorklist.push_back(dst);
724 
725  while (!destWorklist.empty()) {
726  Value currentDst = destWorklist.pop_back_val();
727 
728  // We have found the same operand in some iter arg in the loop structure,
729  // so src and dst are equivalent.
730  if (src == currentDst)
731  return true;
732 
733  // The operands are not equivalent, look for enclosing loops over
734  // currentDst.
735  auto bbArg = dyn_cast<BlockArgument>(currentDst);
736  if (!bbArg)
737  continue;
738 
739  Block *parentBlock = bbArg.getOwner();
740  assert(parentBlock && "unlinked block argument");
741 
742  Operation *parentOp = parentBlock->getParentOp();
743  assert(parentOp && "expected block argument with parent operation");
744 
745  // Check if parent is loop-like. If it's not, do not add it to the worklist.
746  auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
747  if (!parentLoop)
748  continue;
749 
750  for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
751  // No need to check for null as innerIterArg is tied to parentLoop.
752  OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
753  Value loopBlockArgument =
754  parentLoop->getOperand(operand->getOperandNumber());
755  destWorklist.push_back(loopBlockArgument);
756  }
757  }
758 
759  return false;
760 }
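// Intuition for sameOrEquivalentIterArg: starting from `dst`, the worklist
// walks outward through enclosing loop-like ops via their tied loop inits.
// For nested loops such as (illustrative IR)
//
//   scf.for ... iter_args(%outer = %init) {
//     scf.for ... iter_args(%inner = %outer) {
//       ... %inner ...
//     }
//   }
//
// calling it with src = %init and dst = %inner returns true, because %inner is
// transitively tied to the same initial value.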
761 
762 /// Find the first "extract" user of `producerOp` and tile it right before its
763 /// use. The tiled op is fused under the `containingOp`.
764 /// Return this fused op on success or nullptr if anything fails.
765 /// If tiled op has uses that are dominated by `containingOp`, return
766 /// a new `containingOp` with results of the fused op appended to
767 /// results of the `containingOp` or nullptr if there are no dominated uses.
768 static std::tuple<SmallVector<Operation *>, Operation *>
769 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
770  Operation *producerOp, Operation *containingOp) {
771  LDBG() << "Try to fuse a direct extract use";
772  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
773  if (!tileableProducer) {
774  diag.attachNote(producerOp->getLoc())
775  << "producer is not a TileableInterface: " << *producerOp;
776  return {};
777  }
778 
779  // Search the producer slices accessed within the containing operation.
780  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
781  // evolve into an interface.
782  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
783  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
784  return sliceOp && containingOp->isProperAncestor(sliceOp);
785  });
786 
787  // Find a fusion opportunity.
788  if (it == tileableProducer->getUsers().end()) {
789  diag.attachNote(tileableProducer->getLoc())
790  << "could not find fusion opportunity for: " << *tileableProducer;
791  return {};
792  }
793  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
794 
795  // Try to fuse the producer in-place.
796  OpBuilder::InsertionGuard guard(rewriter);
797  rewriter.setInsertionPoint(sliceOpToTile);
798 
799  // Clone the producer inside the consumer and try to update the producer init
800  // operands using the loop bbArgs if applicable. More precisely, if the bbArg
801  // of the container loop points to a value that is used by the consumer op,
802  // then, instead of using such value on the consumer, use the value coming
803  // from the bbArg instead. This allows reusing the output tensor (instead of
804  // creating a new one) of the container when both producer and container
805  // write to the same output.
806  if (LoopLikeOpInterface containerLoop =
807  dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
808  Operation *clone = rewriter.clone(*producerOp);
809  rewriter.modifyOpInPlace(clone, [&]() {
810  // Iterate over the outputs of the producer and over the loop bbArgs and
811  // check if any bbArg points to the same value as the producer output. In
812  // such case, make the producer output point to the bbArg directly.
813  for (OpOperand &initOperandPtr :
814  cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
815  Value producerOperand =
816  clone->getOperand(initOperandPtr.getOperandNumber());
817  for (BlockArgument containerIterArg :
818  containerLoop.getRegionIterArgs()) {
819  OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
820  Value consumerOperand =
821  containerLoop->getOperand(bbArg->getOperandNumber());
822  // The producer has the same init as the loop bbArg, use it.
823  if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
824  initOperandPtr.set(containerIterArg);
825  }
826  }
827  }
828  });
829 
830  tileableProducer = dyn_cast<TilingInterface>(clone);
831  }
832 
833  // Tile the producer.
834  int64_t resultNumber =
835  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
836  LDBG() << "resultNumber: " << resultNumber;
837 
838  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
839  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
840 
841  FailureOr<TilingResult> tileAndFuseResult =
842  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
843  sizes);
844 
845  if (failed(tileAndFuseResult)) {
846  diag.attachNote(tileableProducer->getLoc())
847  << "failed to tile producer op: " << *tileableProducer;
848  return {};
849  }
850 
851 #ifndef NDEBUG
852  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
853  LDBG() << "tiledProducer: " << *tiledOp;
854  }
855 #endif
856 
857  // Replace the extract op.
858  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
859  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
860  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
861  if (failed(maybeRankReduced)) {
862  diag.attachNote(producerOp->getLoc())
863  << "shape types don't match (missing canonicalization?):\nTiledOp: "
864  << tileAndFuseResult->tiledValues[0]
865  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
866  return {};
867  }
868  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
869 
870  // Add new outputs to containing op, if required
871  Operation *newContainingOp = replaceForAllWithNewSignature(
872  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
873  resultNumber, offsets, sizes);
874 
875  // Cleanup clone.
876  if (dyn_cast<LoopLikeOpInterface>(containingOp))
877  rewriter.eraseOp(tileableProducer);
878 
879  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
880 }
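// The IR shape this helper looks for, roughly (names illustrative): a producer
// whose result is read through a tensor.extract_slice nested inside the
// containing op, e.g.
//
//   %p = linalg.generic ... -> tensor<128x128xf32>   // producer
//   scf.forall ... {                                  // containing op
//     %s = tensor.extract_slice %p[...] [...] [...]
//     ... uses of %s ...
//   }
//
// The producer is tiled to the slice's offsets/sizes, the extract_slice is
// replaced with the tiled value, and, if the producer has uses dominated by
// the containing op, the scf.forall additionally gets a new result for them.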
881 
882 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
883 /// it is exactly the `containingOp`, otherwise bail.
884 /// Then, find the first "extract" user of the tied block argument and tile it
885 /// right before its "extract" use. The tiled op is fused under the
886 /// `containingOp`.
887 /// Return this fused op on success or nullptr if anything fails.
888 static SmallVector<Operation *>
889 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
890  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
891  Operation *containingOp) {
892  LDBG() << "Try to fuse an extract use through block argument";
893 
894  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
895  if (!tileableProducer) {
896  diag.attachNote(producerOp->getLoc())
897  << "producer is not a TileableInterface: " << *producerOp;
898  return {};
899  }
900 
901  // Search the first use by a "scf::ForallOp" user.
902  scf::ForallOp forallOp;
903  auto itProducerUses =
904  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
905  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
906  return forallOp;
907  });
908  // If it's not from the containing op, return.
909  if (!forallOp || forallOp != containingOp) {
910  diag.attachNote(tileableProducer->getLoc())
911  << "could not find a use by the containing op: " << *tileableProducer;
912  return {};
913  }
914 
915  // Search the producer slices accessed within the containing
916  // operation.
917  // TODO: Generalize to more extract/insert/parallel_insert triples.
918  // Maybe evolve into an interface.
919  OpOperand *pUse = &(*itProducerUses);
920  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
921 
922  // Search the producer slices accessed within the containing operation.
923  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
924  // evolve into an interface.
925  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
926  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
927  return sliceOp && containingOp->isProperAncestor(sliceOp);
928  });
929 
930  // Find a fusion opportunity.
931  if (itBBArgUsers == bbArg.getUsers().end()) {
932  diag.attachNote(containingOp->getLoc())
933  << "could not find fusion opportunity for bbArg: " << bbArg;
934  return {};
935  }
936  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
937 
938  // Try to fuse the producer in-place.
939  OpBuilder::InsertionGuard guard(rewriter);
940  rewriter.setInsertionPoint(sliceOpToTile);
941 
942  // Replace the use in the tileableProducer before tiling: clone, replace and
943  // then tile.
944  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
945  LDBG() << "resultNumber: " << resultNumber;
946 
947  // Gather destination tensors.
948  SmallVector<Value> destinationTensors;
949  if (failed(tensor::getOrCreateDestinations(
950  rewriter, tileableProducer->getLoc(), tileableProducer,
951  destinationTensors))) {
952  diag.attachNote(tileableProducer->getLoc())
953  << "failed to get destination tensors for: " << *tileableProducer;
954  return {};
955  }
956 
957  IRMapping bvm;
958  bvm.map(destinationTensors[resultNumber], bbArg);
959  auto tileableProducerClone =
960  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
961  auto scopeGuard =
962  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
963 
964  // Tile the producer.
965  FailureOr<TilingResult> tileAndFuseResult =
966  tileableProducerClone.generateResultTileValue(
967  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
968  sliceOpToTile.getMixedSizes());
969  if (failed(tileAndFuseResult)) {
970  diag.attachNote(tileableProducer->getLoc())
971  << "failed to tile producer op: " << *tileableProducer;
972  return {};
973  }
974 
975  // Replace the extract op.
976  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
977  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
978  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
979  assert(succeeded(maybeRankReduced) && "unexpected shape");
980  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
981 
982  // Replace the use in containingOp.
983  rewriter.modifyOpInPlace(containingOp, [&]() {
984  containingOp->setOperand(pUse->getOperandNumber(),
985  destinationTensors.front());
986  });
987 
988  return tileAndFuseResult->tiledOps;
989 }
990 
991 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
992  Operation *producerOp,
993  Operation *containingOp) {
994  LDBG() << "Try to fuse a use by cloning";
995 
996  // Gather all uses inside the containing op.
997  SmallVector<OpOperand *> uses;
998  for (OpResult result : producerOp->getOpResults()) {
999  for (OpOperand &use : result.getUses()) {
1000  if (containingOp->isProperAncestor(use.getOwner())) {
1001  uses.push_back(&use);
1002  continue;
1003  }
1004  // Cannot clone and fuse if the use is by the containing op itself: fail
1005  // immediately.
1006  if (containingOp == use.getOwner()) {
1007  diag.attachNote(producerOp->getLoc())
1008  << "producer op use by containing op cannot be fused by cloning";
1009  return nullptr;
1010  }
1011  }
1012  }
1013 
1014  // Check for a non-empty list of fusion opportunities.
1015  if (uses.empty()) {
1016  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1017  return nullptr;
1018  }
1019 
1020  // Clone and fuse inside the containing op.
1021  Operation *fusedOp = nullptr;
1022  OpOperand *use = uses.front();
1023  // Parallel insert slice is not a valid clone destination.
1024  // TODO: Generalize to other type of ops.
1025  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1026  "Parallel insert slice is not a valid clone destination");
1027  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1028  LDBG() << "resultNumber: " << resultNumber;
1029 
1030  OpBuilder::InsertionGuard guard(rewriter);
1031  rewriter.setInsertionPoint(use->getOwner());
1032  fusedOp = rewriter.clone(*producerOp);
1033  rewriter.modifyOpInPlace(
1034  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1035 
1036  return fusedOp;
1037 }
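// Note on how the helpers above are combined: FuseIntoContainingOp::apply
// below tries them in order for each producer: first the direct extract-slice
// path, then the variant that goes through a containing-op block argument, and
// finally fusion by cloning, stopping at the first strategy that succeeds.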
1038 
1039 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1040  // Allow repeated handles since we are fusing everything anyway.
1041  return true;
1042 }
1043 
1044 DiagnosedSilenceableFailure
1045 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1046  transform::TransformResults &results,
1047  transform::TransformState &state) {
1048  SmallVector<Operation *> fusedOps;
1049  auto producerOps = state.getPayloadOps(getProducerOp());
1050  auto containingOps = state.getPayloadOps(getContainingOp());
1051  if (!llvm::hasSingleElement(containingOps)) {
1052  return emitDefiniteFailure()
1053  << "requires exactly one containing_op handle (got "
1054  << llvm::range_size(containingOps) << ")";
1055  }
1056  Operation *containingOp = *containingOps.begin();
1057 
1058  // If nothing to fuse, propagate success.
1059  if (std::empty(producerOps)) {
1060  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1061  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1062  return DiagnosedSilenceableFailure::success();
1063  }
1064 
1065  // Helper function to find the next producer that should be fused. Take any
1066  // producer that has a use inside the containing op.
1067  SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1068  auto getNextProducer = [&]() -> FailureOr<Operation *> {
1069  for (const auto &it : enumerate(remainingProducers)) {
1070  Operation *producerOp = it.value();
1071  // The containing op may be a user of producerOp: use isAncestor.
1072  int64_t numUsesInContainingOp =
1073  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1074  return containingOp->isAncestor(op);
1075  });
1076  // TODO: When resolving the TODO below (no duplicate ops), take an op
1077  // that has no use among the remaining producers. This is a topological
1078  // sorting.
1079  if (numUsesInContainingOp > 0) {
1080  if (numUsesInContainingOp == 1)
1081  remainingProducers.erase(remainingProducers.begin() + it.index());
1082  return producerOp;
1083  }
1084  }
1085  return failure();
1086  };
1087 
1088  while (!remainingProducers.empty()) {
1089  auto nextProducer = getNextProducer();
1090  if (failed(nextProducer)) {
1091  auto diag = mlir::emitSilenceableFailure(getLoc())
1092  << "could not find next producer to fuse into container";
1093  diag.attachNote(containingOp->getLoc()) << "containing op";
1094  return diag;
1095  }
1096 
1097  Operation *producerOp = *nextProducer;
1098 
1099  // Default diagnostic, to be complemented with more failure information.
1100  Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
1101  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1102 
1103  // TODO: If there are multiple uses of the producer in the containing op,
1104  // we currently tile/clone the op multiple times (once per use). In some
1105  // cases, we can tile/clone once and reuse the value for each use.
1106  // Furthermore, producers should then be traversed according to a
1107  // topological sorting.
1108  auto [tiledOps, newContainingOp] =
1109  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1110  if (!tiledOps.empty()) {
1111  LDBG() << "\nFused a direct extract use\n" << *containingOp;
1112  fusedOps.append(tiledOps);
1113  if (newContainingOp) {
1114  // Update handles associated with the containing op so we don't need to
1115  // invalidate them. This is a hack to support better composability
1116  // between tiling and fusion while a proper mechanism is being
1117  // investigated.
1118  //
1119  // DO NOT replicate this elsewhere unless you understand what you are
1120  // doing.
1121  LogicalResult replacementStatus =
1122  rewriter.notifyPayloadOperationReplaced(containingOp,
1123  newContainingOp);
1124  (void)replacementStatus;
1125  assert(succeeded(replacementStatus) &&
1126  "unable to update transform state mapping");
1127  rewriter.eraseOp(containingOp);
1128  containingOp = newContainingOp;
1129  }
1130  continue;
1131  }
1132 
1133  SmallVector<Operation *> tiledContainingOpOperand =
1134  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1135  rewriter, diag, producerOp, containingOp);
1136  if (!tiledContainingOpOperand.empty()) {
1137  LDBG() << "\nFused an extract use through block argument\n"
1138  << *containingOp;
1139  fusedOps.append(tiledContainingOpOperand);
1140  continue;
1141  }
1142 
1143  Operation *cloned =
1144  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1145  if (cloned) {
1146  LDBG() << "\nFused a use by cloning\n" << *containingOp;
1147  fusedOps.push_back(cloned);
1148  continue;
1149  }
1150  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1151  }
1152 
1153  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1154  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1155  return DiagnosedSilenceableFailure::success();
1156 }
1157 
1158 void transform::FuseIntoContainingOp::getEffects(
1159  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1160  consumesHandle(getProducerOpMutable(), effects);
1161  onlyReadsHandle(getContainingOpMutable(), effects);
1162  producesHandle(getOperation()->getOpResults(), effects);
1163  modifiesPayload(effects);
1164 }
1165 
1166 //===----------------------------------------------------------------------===//
1167 // GeneralizeOp
1168 //===----------------------------------------------------------------------===//
1169 
1170 DiagnosedSilenceableFailure
1171 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1172  LinalgOp target,
1173  transform::ApplyToEachResultList &results,
1174  transform::TransformState &state) {
1175  // Exit early if no transformation is needed.
1176  if (isa<GenericOp>(target)) {
1177  results.push_back(target);
1178  return DiagnosedSilenceableFailure::success();
1179  }
1180  rewriter.setInsertionPoint(target);
1181  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1182  if (succeeded(generic)) {
1183  results.push_back(generic->getOperation());
1184  return DiagnosedSilenceableFailure::success();
1185  }
1186  return emitDefaultSilenceableFailure(target);
1187 }
1188 
1189 //===----------------------------------------------------------------------===//
1190 // SpecializeOp
1191 //===----------------------------------------------------------------------===//
1192 
1193 DiagnosedSilenceableFailure
1194 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1195  LinalgOp target,
1196  transform::ApplyToEachResultList &results,
1197  transform::TransformState &state) {
1198  // Exit early if the operation is not a generic.
1199  if (!isa<GenericOp>(target)) {
1200  results.push_back(target);
1201  return DiagnosedSilenceableFailure::success();
1202  }
1203  rewriter.setInsertionPoint(target);
1204  FailureOr<LinalgOp> named =
1205  specializeGenericOp(rewriter, cast<GenericOp>(target));
1206  if (succeeded(named)) {
1207  results.push_back(named->getOperation());
1208  return DiagnosedSilenceableFailure::success();
1209  }
1210  return emitDefaultSilenceableFailure(target);
1211 }
1212 
1213 //===----------------------------------------------------------------------===//
1214 // InterchangeOp
1215 //===----------------------------------------------------------------------===//
1216 
1217 DiagnosedSilenceableFailure
1218 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1219  GenericOp target,
1220  transform::ApplyToEachResultList &results,
1221  transform::TransformState &state) {
1222  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1223  // Exit early if no transformation is needed.
1224  if (interchangeVector.empty()) {
1225  results.push_back(target);
1226  return DiagnosedSilenceableFailure::success();
1227  }
1228 
1229  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1230  if (interchangeVector.size() != numLoops) {
1231  return emitSilenceableError()
1232  << getIteratorInterchangeAttrName() << " has length ("
1233  << interchangeVector.size()
1234  << ") different from the number of loops in the target operation ("
1235  << numLoops << ")";
1236  }
1237  FailureOr<GenericOp> res = interchangeGenericOp(
1238  rewriter, target, SmallVector<unsigned>(interchangeVector));
1239  if (failed(res))
1240  return emitDefiniteFailure() << "failed to apply";
1241  results.push_back(res->getOperation());
1242  return DiagnosedSilenceableFailure::success();
1243 }
1244 
1245 LogicalResult transform::InterchangeOp::verify() {
1246  ArrayRef<int64_t> permutation = getIteratorInterchange();
1247  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1248  if (!std::is_permutation(sequence.begin(), sequence.end(),
1249  permutation.begin(), permutation.end())) {
1250  return emitOpError()
1251  << "expects iterator_interchange to be a permutation, found "
1252  << getIteratorInterchange();
1253  }
1254  return success();
1255 }
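// Small illustration of the semantics checked above: for a generic op with two
// loops, iterator_interchange = [1, 0] swaps the order of loops 0 and 1,
// [0, 1] is the identity, and a non-permutation such as [1, 1] is rejected.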
1256 
1257 //===----------------------------------------------------------------------===//
1258 // LinalgCopyToMemrefOp
1259 //===----------------------------------------------------------------------===//
1260 
1261 DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1262  transform::TransformRewriter &rewriter, Operation *targetOp,
1263  transform::ApplyToEachResultList &results,
1264  transform::TransformState &state) {
1265 
1266  // Check if the target can be converted.
1267  if (!isa<linalg::CopyOp>(targetOp)) {
1268  DiagnosedSilenceableFailure diag =
1269  emitSilenceableError() << "only linalg.copy target ops are supported";
1270  diag.attachNote(targetOp->getLoc()) << "target op";
1271  return diag;
1272  }
1273 
1274  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1275  if (!copyOp.hasPureBufferSemantics()) {
1276  DiagnosedSilenceableFailure diag =
1277  emitSilenceableError()
1278  << "cannot transform a linalg.copy on tensors into a memref.copy";
1279  diag.attachNote(targetOp->getLoc()) << "target op";
1280  return diag;
1281  }
1282 
1283  SmallVector<Value> inputs = copyOp.getInputs();
1284  SmallVector<Value> outputs = copyOp.getOutputs();
1285  assert(inputs.size() == 1 && "expected linalg copy op with one input");
1286  assert(outputs.size() == 1 && "expected memref copy op with one output");
1287  Value input = inputs.front();
1288  Value output = outputs.front();
1289 
1290  // linalg.copy supports different element types on source/dest whereas
1291  // memref.copy does not, so we must check that the source and dest types can
1292  // be handled by memref.copy and otherwise reject the transformation.
1293  if (!isa<ShapedType>(input.getType())) {
1294  DiagnosedSilenceableFailure diag =
1295  emitSilenceableError()
1296  << "cannot transform a linalg.copy whose input has no shape";
1297  diag.attachNote(targetOp->getLoc()) << "target op";
1298  return diag;
1299  }
1300 
1301  // linalg.copy destination must be a shaped type.
1302  assert(isa<ShapedType>(output.getType()));
1303 
1304  if (cast<ShapedType>(input.getType()).getElementType() !=
1305  cast<ShapedType>(output.getType()).getElementType()) {
1306  DiagnosedSilenceableFailure diag =
1307  emitSilenceableError()
1308  << "cannot transform a linalg.copy with different source and "
1309  "destination element types ";
1310  diag.attachNote(targetOp->getLoc()) << "target op";
1311  return diag;
1312  }
1313 
1314  // Target can be converted, do it.
1315  auto memrefCopyOp =
1316  rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1317 
1318  results.push_back(memrefCopyOp);
1319  return DiagnosedSilenceableFailure::success();
1320 }
1321 
1322 //===----------------------------------------------------------------------===//
1323 // LowerPackOp
1324 //===----------------------------------------------------------------------===//
1325 
1326 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1327  transform::TransformRewriter &rewriter, linalg::PackOp target,
1328  transform::ApplyToEachResultList &transformResults,
1329  transform::TransformState &state) {
1330  rewriter.setInsertionPoint(target);
1331  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1332  FailureOr<LowerPackResult> res =
1333  lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1334  if (failed(res)) {
1335  return mlir::emitSilenceableFailure(target->getLoc())
1336  << "cannot lower to pad + expand + transpose";
1337  }
1338  transformResults.push_back(res->padOp);
1339  transformResults.push_back(res->expandShapeOp);
1340  transformResults.push_back(res->transposeOp);
1341  return DiagnosedSilenceableFailure::success();
1342 }
1343 
1344 //===----------------------------------------------------------------------===//
1345 // LowerUnPackOp
1346 //===----------------------------------------------------------------------===//
1347 
1348 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1349  transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1350  transform::ApplyToEachResultList &transformResults,
1351  transform::TransformState &state) {
1352  rewriter.setInsertionPoint(target);
1353  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1354  FailureOr<LowerUnPackOpResult> res =
1355  lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1356  if (failed(res)) {
1357  DiagnosedSilenceableFailure diag =
1358  emitSilenceableError()
1359  << "cannot lower to transpose + collapse + extract";
1360  diag.attachNote(target->getLoc()) << "target payload op";
1361  return diag;
1362  }
1363  transformResults.push_back(res->emptyOp);
1364  transformResults.push_back(res->transposeOp);
1365  transformResults.push_back(res->collapseShapeOp);
1366  transformResults.push_back(res->extractSliceOp);
1367  return DiagnosedSilenceableFailure::success();
1368 }
1369 
1370 //===---------------------------------------------------------------------===//
1371 // MatchOp
1372 //===---------------------------------------------------------------------===//
1373 
1374 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1375  Value target, ArrayRef<StringRef> opNames) {
1376  result.addOperands(target);
1377  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1378  builder.getStrArrayAttr(opNames));
1379  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1380 }
1381 
1382 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1383  TypeRange resultTypes, Value target,
1384  ArrayRef<StringRef> opNames) {
1385  result.addOperands(target);
1386  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1387  builder.getStrArrayAttr(opNames));
1388  result.addTypes(resultTypes);
1389 }
1390 
1391 DiagnosedSilenceableFailure
1392 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1393  transform::TransformResults &results,
1394  transform::TransformState &state) {
1395  llvm::StringSet<> strs;
1396  if (getOps().has_value())
1397  strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1398 
1399  auto payloadOps = state.getPayloadOps(getTarget());
1400  if (!llvm::hasSingleElement(payloadOps)) {
1401  return emitDefiniteFailure("requires exactly one target handle");
1402  }
1403 
1404  SmallVector<Operation *> res;
1405  bool incorrectNumOperandTypes = false;
1406  auto matchFun = [&](Operation *op) {
1407  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1408  return;
1409 
1410  // Interfaces cannot be matched by name, just by ID.
1411  // So we specifically encode the interfaces we care about for this op.
1412  if (getInterface().has_value()) {
1413  auto iface = getInterface().value();
1414  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1415  !isa<LinalgOp>(op))
1416  return;
1417  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1418  !isa<TilingInterface>(op))
1419  return;
1420  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1421  !isa<LoopLikeOpInterface>(op))
1422  return;
1423  }
1424 
1425  // Check if all specified attributes match.
1426  if (getOpAttrs().has_value()) {
1427  DictionaryAttr opAttrs = getOpAttrs().value();
1428  for (NamedAttribute attr : opAttrs) {
1429  if (attr.getName() == getInterfaceAttrName() ||
1430  attr.getName() == getOpsAttrName())
1431  continue;
1432  if (!op->hasAttr(attr.getName()))
1433  return;
1434  if (op->getAttr(attr.getName()) != attr.getValue())
1435  return;
1436  }
1437  }
1438 
1439  if (getFilterResultType().has_value()) {
1440  Type t = getFilterResultType().value();
1441  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1442  return;
1443  }
1444 
1445  if (getFilterOperandTypes().has_value()) {
1446  mlir::ArrayAttr types = getFilterOperandTypes().value();
1447  auto operandTypes = op->getOperandTypes();
1448 
1449  if (types.size() == 1) {
1450  // All the operands must be equal to the specified type.
1451  auto typeattr =
1452  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1453  Type t = cast<::mlir::Type>(typeattr.getValue());
1454  if (!llvm::all_of(op->getOperandTypes(),
1455  [&](Type operandType) { return operandType == t; }))
1456  return;
1457  } else {
1458  // The operand types must match all the types in the list (in the same
1459  // order in which they are specified).
1460  if (types.size() != operandTypes.size()) {
1461  incorrectNumOperandTypes = true;
1462  return;
1463  }
1464 
1465  for (auto [attr, operandType] :
1466  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1467  auto typeattr = cast<mlir::TypeAttr>(attr);
1468  Type type = cast<::mlir::Type>(typeattr.getValue());
1469 
1470  if (type != operandType)
1471  return;
1472  }
1473  }
1474  }
1475 
1476  // All constraints are satisfied.
1477  res.push_back(op);
1478  return;
1479  };
1480 
1481  (*payloadOps.begin())->walk(matchFun);
1482  if (incorrectNumOperandTypes)
1483  return emitDefiniteFailure("If filter_operand_types contains more than "
1484  "one type, then it must contain as many types "
1485  "as the number of operands in the target ops");
1486  results.set(cast<OpResult>(getResult()), res);
1487  return DiagnosedSilenceableFailure::success();
1488 }
1489 
1490 //===---------------------------------------------------------------------===//
1491 // MultiTileSizesOp
1492 //===---------------------------------------------------------------------===//
1493 
1494 static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1495  Type targetType, Type lowSizeType, Type,
1496  Type) {
1497  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1498 }
1499 
1500 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1501  Type &targetType, Type &lowSizeType,
1502  Type &highSizeType,
1503  Type &splitPointType) {
1504  FunctionType funcType;
1505  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1506  if (failed(parser.parseType<FunctionType>(funcType)))
1507  return failure();
1508 
1509  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1510  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1511  "argument and one result";
1512  }
1513  targetType = funcType.getInput(0);
1514  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1515 
1516  return success();
1517 }
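// With the custom printer/parser above, the op's trailing type is written as a
// functional type with a single input and a single result, e.g. (illustrative)
//
//   : (!transform.any_op) -> !transform.param<i64>
//
// where the single result type is reused for low_size, high_size and
// split_point, matching the verifier's same-type requirement below.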
1518 
1519 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1520  transform::TransformRewriter &rewriter, LinalgOp target,
1521  transform::ApplyToEachResultList &results, TransformState &state) {
1522  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1523  if (target.hasDynamicShape()) {
1524  auto diag = emitSilenceableError()
1525  << "cannot compute parametric tile sizes for dynamically "
1526  "shaped payload op";
1527  diag.attachNote(target->getLoc()) << "payload op";
1528  return diag;
1529  }
1530 
1531  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1532  target, getDimension(), getTargetSize(), getDivisor());
1533  if (failed(spec)) {
1534  return emitSilenceableError()
1535  << "failed to compute multi-size tiling sizes";
1536  }
1537 
1538  Builder builder(target.getContext());
1539  results.assign(llvm::map_range(
1540  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1541  spec->lowTileSize * spec->lowTripCount}),
1542  [&builder, this](int64_t value) {
1543  return builder.getIntegerAttr(
1544  cast<ParamType>(getLowSize().getType()).getType(), value);
1545  }));
1547  }
1548 
1549  OpBuilder builder(target.getContext());
1550  builder.setInsertionPoint(target);
1551  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1552  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1553  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1554  builder, target, getDimension(), targetSize, divisor);
1555  if (failed(spec)) {
1556  return emitSilenceableError() << "could not generate tile size computation";
1557  }
1558 
1559  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1560  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1561  Operation *splitPoint =
1562  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1563  {spec->lowTileSize, spec->lowTripCount});
1564  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1565  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1566  assert(lowTileSize && highTileSize && splitPoint &&
1567  "tile sizes are not produced by operations");
1568  results.reserve(results.size() + 3);
1569  results.push_back(lowTileSize);
1570  results.push_back(highTileSize);
1571  results.push_back(splitPoint);
1573 }
1574 
1575 void transform::MultiTileSizesOp::getEffects(
1577  onlyReadsHandle(getTargetMutable(), effects);
1578  producesHandle(getOperation()->getOpResults(), effects);
1579  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1580  onlyReadsPayload(effects);
1581  else
1582  modifiesPayload(effects);
1583 }
1584 
1585 LogicalResult transform::MultiTileSizesOp::verify() {
1586  if (getLowSize().getType() != getHighSize().getType() ||
1587  getLowSize().getType() != getSplitPoint().getType()) {
1588  return emitOpError() << "expects all results type to be the same";
1589  }
1590  return success();
1591 }
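
// A hedged usage sketch of the static multi-size computation driven above.
// The helper name and the concrete dimension/target-size/divisor values are
// illustrative placeholders, not defaults of the op.
static FailureOr<SmallVector<int64_t>>
computeStaticMultiTileSizesExample(LinalgOp linalgOp) {
  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
      linalgOp, /*dimension=*/0, /*targetSize=*/32, /*divisor=*/8);
  if (failed(spec))
    return failure();
  // Low tile size, high tile size, and the split point that separates them.
  return SmallVector<int64_t>{spec->lowTileSize, spec->highTileSize,
                              spec->lowTileSize * spec->lowTripCount};
}
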
1592 
1593 //===---------------------------------------------------------------------===//
1594 // PackOp
1595 //===---------------------------------------------------------------------===//
1596 
1597 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1598  Value target,
1599  ArrayRef<OpFoldResult> mixedPackedSizes) {
1600  SmallVector<int64_t> staticPackedSizes;
1601  SmallVector<Value> dynamicPackedSizes;
1602  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1603  staticPackedSizes);
1604  // Call the default builder, which sets up the proper operand segment size
1605  // attributes for multiple variadic operands. In the absence of this,
1606  // horrible bugs ensue.
1607  Type linalgOpHType = transform::OperationType::get(
1608  builder.getContext(), GenericOp::getOperationName());
1609  build(builder, result,
1610  /*resultType=*/linalgOpHType,
1611  /*target=*/target,
1612  /*dynamic_sizes=*/dynamicPackedSizes,
1613  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1614 }
1615 
1616 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1617  Builder b(getContext());
1618  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1619 }
1620 
1622 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1623  transform::TransformResults &transformResults,
1624  transform::TransformState &state) {
1625  auto targetOps = state.getPayloadOps(getTarget());
1626  // If nothing to pack, propagate success.
1627  if (std::empty(targetOps)) {
1628  transformResults.set(cast<OpResult>(getPackedOp()),
1629  ArrayRef<Operation *>({}));
1631  }
1632  // Fail on multi-op handles.
1633  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1634  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1635  return emitSilenceableError()
1636  << "requires target to map to exactly 1 LinalgOp (got "
1637  << llvm::range_size(targetOps) << ")";
1638  }
1639  // Fail on mismatched number of pack sizes.
1640  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1641  return emitSilenceableError()
1642  << "requires number of packed sizes match the number of loops ("
1643  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1644  << ")";
1645  }
1646 
1647  // Unpack handles to constants or actual SSA index values.
1648  SmallVector<OpFoldResult> packedSizes;
1650  state, *this, packedSizes, getMixedPackedSizes());
1651 
1652  rewriter.setInsertionPoint(linalgOp);
1653  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1654  if (failed(maybeResult))
1655  return emitDefiniteFailure("data tiling failed");
1656 
1657  transformResults.set(cast<OpResult>(getPackedOp()),
1658  {maybeResult->packedLinalgOp.getOperation()});
1660 }
1661 
1662 void transform::PackOp::getEffects(
1664  transform::consumesHandle(getTargetMutable(), effects);
1665  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1666  transform::producesHandle(getOperation()->getOpResults(), effects);
1667  transform::modifiesPayload(effects);
1668 }
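
// A minimal sketch of driving the same packing entry point outside the
// transform op, assuming the caller provides the rewriter and a LinalgOp with
// three loops (e.g. a matmul); the packed sizes are illustrative only.
static FailureOr<PackResult> packExample(RewriterBase &rewriter,
                                         LinalgOp linalgOp) {
  // One packed size is required per loop of the target op.
  SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
                                           rewriter.getIndexAttr(16),
                                           rewriter.getIndexAttr(32)};
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  return pack(rewriter, linalgOp, packedSizes);
}
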
1669 
1670 //===---------------------------------------------------------------------===//
1671 // PackGreedilyOp.
1672 //===---------------------------------------------------------------------===//
1673 
1674 LogicalResult transform::PackGreedilyOp::verify() {
1675  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1676  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1677  << " is not a valid permutation";
1678  }
1679  // TODO: relax to allow empty once we have a strategy other than just matmul.
1680  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1681  for (auto [s, nmo] :
1682  llvm::zip_equal(getMixedMatmulPackedSizes(),
1683  getMatmulPaddedSizesNextMultipleOf())) {
1684  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1685  if (nmo != 0 &&
1686  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1687  return emitOpError() << "at most one of the packed_size and the "
1688  "padded_sizes_next_multiple_of can be nonzero "
1689  "for the matmul strategy";
1690  }
1691  }
1692  }
1693  return success();
1694 }
1695 
1697 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1698  transform::TransformResults &transformResults,
1699  transform::TransformState &state) {
1700  SmallVector<Operation *> results;
1701  for (Operation *op : state.getPayloadOps(getTarget())) {
1702  auto linalgOp = dyn_cast<LinalgOp>(op);
1703  if (!linalgOp)
1704  continue;
1705  // linalgOp will be replaced and the insertion point may be invalidated if
1706  // we set it before the op, so set it after instead.
1707  rewriter.setInsertionPointAfter(linalgOp);
1708  // Failing to pack greedily is perfectly fine.
1709  // In the future we will want to order packings according to some metric.
1710  FailureOr<PackResult> packResult = packMatmulGreedily(
1711  /*rewriter=*/rewriter,
1712  /*linalgOp=*/linalgOp,
1713  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1714  /*mnkPaddedSizesNextMultipleOf=*/
1715  getMatmulPaddedSizesNextMultipleOf(),
1716  /*mnkOrder=*/getMatmulInnerDimsOrder());
1717  if (succeeded(packResult)) {
1718  results.push_back(packResult->packedLinalgOp);
1719  continue;
1720  }
1721  results.push_back(linalgOp);
1722  }
1723  transformResults.set(cast<OpResult>(getPackedOp()), results);
1725 }
1726 
1727 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1728  Builder b(getContext());
1729  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1730  b);
1731 }
1732 
1733 void transform::PackGreedilyOp::getEffects(
1735  transform::consumesHandle(getTargetMutable(), effects);
1736  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1737  transform::producesHandle(getOperation()->getOpResults(), effects);
1738  transform::modifiesPayload(effects);
1739 }
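
// A hedged sketch of the greedy matmul packing call used above, assuming a
// caller-provided rewriter and LinalgOp; the m, n, k packed sizes, the empty
// next-multiple-of padding, and the identity mnk order are illustrative.
static FailureOr<PackResult> packMatmulGreedilyExample(RewriterBase &rewriter,
                                                       LinalgOp linalgOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(linalgOp);
  SmallVector<OpFoldResult> mnkPackedSizes = {rewriter.getIndexAttr(8),
                                              rewriter.getIndexAttr(16),
                                              rewriter.getIndexAttr(32)};
  return packMatmulGreedily(rewriter, linalgOp, mnkPackedSizes,
                            /*mnkPaddedSizesNextMultipleOf=*/{},
                            /*mnkOrder=*/{0, 1, 2});
}
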
1740 
1741 //===---------------------------------------------------------------------===//
1742 // PackTransposeOp
1743 //===---------------------------------------------------------------------===//
1744 
1745 LogicalResult transform::PackTransposeOp::verify() {
1746  if (!isPermutationVector(getInnerPerm())) {
1747  return emitOpError() << getInnerPermAttrName()
1748  << " is not a valid permutation";
1749  }
1750  if (!isPermutationVector(getOuterPerm())) {
1751  return emitOpError() << getOuterPermAttrName()
1752  << " is not a valid permutation";
1753  }
1754  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1755  return emitOpError() << "at least one of " << getInnerPermAttrName()
1756  << " or " << getOuterPermAttrName()
1757  << " must be specified";
1758  }
1759  return success();
1760 }
1761 
1762 namespace {
1763 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1764 } // namespace
1765 
1766 /// Return true if `permutation` is a valid permutation of the
1767 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1768  /// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` `op`.
1769 /// This is the case when the `permutation` rank matches the rank expected by
1770 /// `op` and `permutation` is itself a permutation vector.
1771  /// Return true if either `op` or `permutation` is empty to allow a simpler
1772 /// polymorphic implementation.
1773 template <typename RelayoutOpTy>
1775  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1776  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1777  static_assert(
1778  llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1779  "applies to only pack or unpack operations");
1780  if (!op || permutation.empty())
1781  return true;
1782  size_t innerRank = op.getInnerDimsPos().size();
1783  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1784  return permutation.size() == innerRank && isPermutationVector(permutation);
1785  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1786  // Don't rely on it.
1787  if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1788  return permutation.size() == op.getSourceRank() &&
1789  isPermutationVector(permutation);
1790  }
1791  return permutation.size() == op.getDestRank() &&
1792  isPermutationVector(permutation);
1793 }
1794 
1796 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1797  transform::TransformResults &transformResults,
1798  transform::TransformState &state) {
1799  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1800  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1801  // Step 1. If nothing to pack, propagate success.
1802  if (std::empty(packOrUnpackOps)) {
1803  transformResults.set(cast<OpResult>(getPackedOp()), {});
1804  transformResults.set(cast<OpResult>(getPackOp()), {});
1805  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1807  }
1808 
1809  // Step 2. A bunch of runtime sanity checks and error messages.
1810  // Step 2.1. Fail on multi-op handles.
1811  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1812  !llvm::hasSingleElement(linalgOps)) {
1813  return emitSilenceableError()
1814  << "requires target to map to exactly 1 "
1815  "packing op and 1 packed op ("
1816  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1817  << llvm::range_size(linalgOps) << ")";
1818  }
1819 
1820  // Step 2.2. Fail on wrong type.
1821  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1822  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1823  if ((!packOp && !unPackOp)) {
1824  return emitSilenceableError() << "requires target to map to a "
1825  "linalg.pack or linalg.unpack";
1826  }
1827  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1828  if (!linalgOpTarget)
1829  return emitSilenceableError() << "requires a LinalgOp target";
1830 
1831  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1832  LinalgOp linalgOp;
1833  if (packOp && packOp.getResult().hasOneUse())
1834  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1835  else if (unPackOp)
1836  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1837  if (linalgOp != linalgOpTarget) {
1838  auto errorMsg =
1839  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1840  : StringLiteral{"not produced by the LinalgOp target"};
1841  return emitSilenceableError() << errorMsg;
1842  }
1843 
1844  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1845  // PackOp.
1846  if (unPackOp) {
1847  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1848  OpOperand *packUse = linalgOp.getDpsInitOperand(
1849  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1850  packOp = packUse->get().getDefiningOp<linalg::PackOp>();
1851  if (!packOp || !packOp.getResult().hasOneUse())
1852  return emitSilenceableError() << "could not find matching pack op";
1853  }
1854 
1855  // Step 2.5. Fail if any permutation does not validate.
1856  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1857  ArrayRef<int64_t> perm =
1858  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1859  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1860  ? StringLiteral{"invalid outer_perm"}
1861  : StringLiteral{"invalid inner_perm"};
1862  if (!isValidPackingPermutation(packOp, perm, permType) ||
1863  !isValidPackingPermutation(unPackOp, perm, permType)) {
1864  Operation *packOrUnpackOp =
1865  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1866  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1867  }
1868  }
1869 
1870  // From here on, packOp and linalgOp are always present, unPackOp may or may
1871  // not be present.
1872  assert(packOp && linalgOp && "unexpected null op");
1873 
1874  // Step 3. Actually transpose the ops.
1875  FailureOr<PackTransposeResult> res = packTranspose(
1876  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1877  // Preconditions have been checked, it is an error to fail here.
1878  assert(succeeded(res) && "unexpected packTranspose failure");
1879 
1880  // Step 4. Return results.
1881  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1882  transformResults.set(cast<OpResult>(getPackedOp()),
1883  {res->transposedLinalgOp});
1884  if (unPackOp) {
1885  transformResults.set(cast<OpResult>(getUnPackOp()),
1886  {res->transposedUnPackOp});
1887  } else {
1888  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1889  }
1890 
1892 }
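
// A minimal sketch of the transposition entry point invoked above, assuming
// the caller already identified the pack op, its consuming LinalgOp and,
// optionally, the matching unpack op (which may be a null op, as in the
// transform); the inner permutation {1, 0} is illustrative.
static FailureOr<PackTransposeResult>
packTransposeExample(RewriterBase &rewriter, linalg::PackOp packOp,
                     LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp) {
  SmallVector<int64_t> innerPerm = {1, 0};
  return packTranspose(rewriter, packOp, linalgOp, maybeUnPackOp,
                       /*outerPerm=*/{}, innerPerm);
}
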
1893 
1894 //===---------------------------------------------------------------------===//
1895 // PadOp
1896 //===---------------------------------------------------------------------===//
1897 
1898 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1899  ArrayRef<int64_t> paddingDimensions,
1900  ArrayRef<int64_t> padToMultipleOf,
1901  ArrayRef<int64_t> nofoldFlags,
1902  ArrayRef<Attribute> transposePaddings,
1903  StringRef copyBackOp,
1904  bool usePrescribedTensorShapes) {
1905  auto resultType = transform::AnyOpType::get(b.getContext());
1906  return build(/*odsBuilder=*/b,
1907  /*result=*/result,
1908  /*types=*/TypeRange{resultType, resultType},
1909  /*target=*/target,
1910  /*padding_values=*/ArrayAttr(), // let inference handle this
1911  /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
1912  /*pad_to_multiple_of=*/ValueRange{},
1913  /*padToMultipleOf=*/
1914  (padToMultipleOf.empty()
1915  ? DenseI64ArrayAttr()
1916  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1917  /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
1918  /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
1919  /*copy_back_op=*/b.getStringAttr(copyBackOp),
1920  /*use_prescribed_tensor_shapes=*/
1921  usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
1922 }
1923 
1924 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1925  ArrayRef<int64_t> paddingDimensions,
1926  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1927  ArrayRef<int64_t> nofoldFlags,
1928  ArrayRef<Attribute> transposePaddings,
1929  StringRef copyBackOp,
1930  bool usePrescribedTensorShapes) {
1931  auto resultType = transform::AnyOpType::get(b.getContext());
1932  SmallVector<int64_t> staticPadToMultipleOf;
1933  SmallVector<Value> dynamicPadToMultipleOf;
1934  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1935  staticPadToMultipleOf);
1936  return build(/*odsBuilder=*/b,
1937  /*result=*/result,
1938  /*types=*/TypeRange{resultType, resultType},
1939  /*target=*/target,
1940  /*padding_values=*/ArrayAttr(), // let inference handle this
1941  /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
1942  /*pad_to_multiple_of=*/dynamicPadToMultipleOf,
1943  /*padToMultipleOf=*/staticPadToMultipleOf,
1944  /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
1945  /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
1946  /*copy_back_op=*/copyBackOp,
1947  /*use_prescribed_tensor_shapes=*/usePrescribedTensorShapes);
1948 }
1949 
1950 void PadOp::getEffects(
1952  consumesHandle(getTargetMutable(), effects);
1953  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1954  producesHandle(getOperation()->getOpResults(), effects);
1955  modifiesPayload(effects);
1956 }
1957 
1958 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1959  Builder b(getContext());
1960  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1961 }
1962 
1964 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1965  transform::TransformResults &results,
1966  transform::TransformState &state) {
1967  auto transformOp = cast<TransformOpInterface>(getOperation());
1968  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1969 
1970  for (Operation *target : state.getPayloadOps(getTarget())) {
1971  auto linalgTarget = dyn_cast<LinalgOp>(target);
1972  if (!linalgTarget) {
1973  auto diag = emitSilenceableError() << "expected LinalgOp target";
1974  diag.attachNote(target->getLoc()) << "target op";
1975  return diag;
1976  }
1977 
1978  // Convert the integer packing flags to booleans.
1979  SmallVector<bool> nofoldFlags;
1980  for (int64_t packPadding :
1981  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1982  nofoldFlags.push_back(static_cast<bool>(packPadding));
1983 
1984  // Convert the padding values to attributes.
1985  SmallVector<Attribute> paddingValues;
1986  for (auto const &[untypedAttr, elementOrTensorType] :
1987  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1988 
1989  if (isa<ub::PoisonAttr>(untypedAttr)) {
1990  paddingValues.push_back(untypedAttr);
1991  continue;
1992  }
1993  auto attr = dyn_cast<TypedAttr>(untypedAttr);
1994  if (!attr) {
1995  emitOpError("expects padding values to be typed attributes or poison");
1997  }
1998  Type elementType = getElementTypeOrSelf(elementOrTensorType);
1999  // Try to parse string attributes to obtain an attribute of element type.
2000  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2001  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2002  stringAttr, getContext(), elementType,
2003  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2004  if (!parsedAttr || parsedAttr.getType() != elementType) {
2005  auto diag = this->emitOpError("expects a padding that parses to ")
2006  << elementType << ", got " << untypedAttr;
2007  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2009  }
2010  paddingValues.push_back(parsedAttr);
2011  continue;
2012  }
2013  // Otherwise, add the attribute directly.
2014  if (attr.getType() != elementType) {
2015  auto diag = this->emitOpError("expects a padding value of type ")
2016  << elementType << ", got " << attr;
2017  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2019  }
2020  paddingValues.push_back(attr);
2021  }
2022 
2023  // Extract the transpose vectors.
2024  SmallVector<SmallVector<int64_t>> transposePaddings;
2025  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2026  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2027  cast<ArrayAttr>(transposeVector)));
2028 
2029  LinalgOp paddedOp;
2031  options.paddingDimensions =
2032  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2033 
2034  SmallVector<int64_t> padToMultipleOf;
2036  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2037  if (!status.succeeded())
2038  return status;
2039  if (padToMultipleOf.empty())
2040  padToMultipleOf =
2041  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2042 
2043  options.padToMultipleOf = padToMultipleOf;
2044  options.paddingValues = paddingValues;
2045  options.nofoldFlags = nofoldFlags;
2046  if (getCopyBackOp() ==
2047  bufferization::MaterializeInDestinationOp::getOperationName()) {
2050  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2052  } else if (getCopyBackOp() == kCopyOpNone) {
2054  } else {
2055  llvm_unreachable("unsupported copy_back op");
2056  }
2057  // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2058  bool irChanged = false;
2059  if (getUsePrescribedTensorShapes() &&
2060  linalgTarget.hasPureTensorSemantics()) {
2061  OpBuilder::InsertionGuard g(rewriter);
2062  rewriter.setInsertionPoint(linalgTarget);
2063  for (OpOperand &operand : linalgTarget->getOpOperands()) {
2064  for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2065  if (ShapedType::isStatic(dim))
2066  continue;
2067  options.setSizeToPadTo(operand.getOperandNumber(), i,
2068  tensor::getMixedSize(rewriter,
2069  operand.get().getLoc(),
2070  operand.get(), i));
2071  irChanged = true;
2072  }
2073  }
2074  }
2075 
2076  SmallVector<Value> replacements;
2077  SmallVector<tensor::PadOp> newPadOps;
2078  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2079  replacements, newPadOps))) {
2080  if (irChanged) {
2081  auto diag = emitDefiniteFailure() << "failed to pad op";
2082  diag.attachNote(target->getLoc()) << "target op";
2083  return diag;
2084  }
2085  auto diag = emitSilenceableError() << "failed to pad op";
2086  diag.attachNote(target->getLoc()) << "target op";
2087  return diag;
2088  }
2089 
2090  // We need to perform our own replacement here because this API is still
2091  // used in patterns that "pad and hoist", for which the replacement values
2092  // need to be different.
2093  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2094  // that we have more composable abstractions.
2095  rewriter.replaceOp(linalgTarget, replacements);
2096  paddedOps.push_back(paddedOp);
2097  padOps.append(newPadOps.begin(), newPadOps.end());
2098  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2099  for (Value v : replacements) {
2100  Operation *copyBackOp = v.getDefiningOp();
2101  if (!llvm::is_contained(copyBackOps, copyBackOp))
2102  copyBackOps.push_back(copyBackOp);
2103  }
2104  }
2105  }
2106 
2107  results.set(cast<OpResult>(getPadded()), paddedOps);
2108  results.set(cast<OpResult>(getPad()), padOps);
2109  results.set(cast<OpResult>(getCopy()), copyBackOps);
2111 }
2112 
2113 LogicalResult transform::PadOp::verify() {
2114  SmallVector<int64_t> nofoldFlags =
2115  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2116  if (any_of(nofoldFlags, [](int64_t packPadding) {
2117  return packPadding != 0 && packPadding != 1;
2118  })) {
2119  return emitOpError()
2120  << "expects nofold_flags to contain booleans (0/1), found "
2121  << getNofoldFlags();
2122  }
2123 
2124  SmallVector<int64_t> paddingDimensions =
2125  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2126  if (any_of(paddingDimensions,
2127  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2128  return emitOpError() << "expects padding_dimensions to contain "
2129  "non-negative integers, found "
2130  << getPaddingDimensions();
2131  }
2132  if (!getMixedPadToMultipleOf().empty()) {
2133  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2134  return emitOpError() << "expects as many multiples as padding_dimensions";
2135  }
2136  }
2137  ArrayAttr transposes = getTransposePaddings();
2138  for (Attribute attr : transposes) {
2139  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2140  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2141  if (!std::is_permutation(sequence.begin(), sequence.end(),
2142  transpose.begin(), transpose.end())) {
2143  return emitOpError()
2144  << "expects transpose_paddings to be a permutation, found "
2145  << attr;
2146  }
2147  }
2148  if (getCopyBackOp() !=
2149  bufferization::MaterializeInDestinationOp::getOperationName() &&
2150  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2151  getCopyBackOp() != kCopyOpNone)
2152  return emitOpError() << "invalid copy_back_op";
2153  return success();
2154 }
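
// A hedged sketch of the padding rewrite driven by this op, assuming the
// caller provides a rewriter, a tensor-semantics LinalgOp with three operands
// (e.g. a matmul), and a zero attribute of the element type; the dimensions
// and multiples are illustrative, and the copy-back op is left at its default.
static LogicalResult padExample(RewriterBase &rewriter, LinalgOp linalgTarget,
                                Attribute zeroPaddingValue) {
  LinalgPaddingOptions options;
  // Pad the two leading iteration dimensions to a multiple of 16, using one
  // padding value per operand, and allow the pads to fold away.
  options.paddingDimensions = SmallVector<int64_t>{0, 1};
  options.padToMultipleOf = SmallVector<int64_t>{16, 16};
  options.paddingValues = SmallVector<Attribute>{
      zeroPaddingValue, zeroPaddingValue, zeroPaddingValue};
  options.nofoldFlags = SmallVector<bool>{false, false, false};
  LinalgOp paddedOp;
  SmallVector<Value> replacements;
  SmallVector<tensor::PadOp> newPadOps;
  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
                               replacements, newPadOps)))
    return failure();
  rewriter.replaceOp(linalgTarget, replacements);
  return success();
}
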
2155 
2156 //===---------------------------------------------------------------------===//
2157 // PadTilingInterfaceOp
2158 //===---------------------------------------------------------------------===//
2159 
2160 void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2161  OperationState &result,
2162  Value target,
2163  ArrayRef<int64_t> paddingSizes,
2164  bool padToMultipleOf) {
2165  auto resultType = transform::AnyOpType::get(b.getContext());
2166  return build(/*odsBuilder=*/b,
2167  /*result=*/result,
2168  /*types=*/TypeRange{resultType, resultType},
2169  /*target=*/target,
2170  /*padding_values=*/ArrayAttr(), // let inference handle this
2171  /*padding_sizes=*/ValueRange{},
2172  /*paddingSizes=*/
2173  (paddingSizes.empty() ? DenseI64ArrayAttr()
2174  : b.getDenseI64ArrayAttr(paddingSizes)),
2175  /*pad_to_multiple_of=*/
2176  padToMultipleOf ? b.getUnitAttr() : nullptr);
2177 }
2178 
2179 void transform::PadTilingInterfaceOp::build(
2180  OpBuilder &b, OperationState &result, Value target,
2181  ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2182  auto resultType = transform::AnyOpType::get(b.getContext());
2183  SmallVector<int64_t> staticPaddingSizes;
2184  SmallVector<Value> dynamicPaddingSizes;
2185  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2186  staticPaddingSizes);
2187  return build(/*odsBuilder=*/b,
2188  /*result=*/result,
2189  /*types=*/TypeRange{resultType, resultType},
2190  /*target=*/target,
2191  /*padding_values=*/ArrayAttr(), // let inference handle this
2192  /*padding_sizes=*/dynamicPaddingSizes,
2193  /*paddingSizes=*/staticPaddingSizes,
2194  /*pad_to_multiple_of=*/padToMultipleOf);
2195 }
2196 
2197 void transform::PadTilingInterfaceOp::getEffects(
2199  consumesHandle(getTargetMutable(), effects);
2200  onlyReadsHandle(getPaddingSizesMutable(), effects);
2201  producesHandle(getOperation()->getOpResults(), effects);
2202  modifiesPayload(effects);
2203 }
2204 
2206 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2207  Builder b(getContext());
2208  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2209 }
2210 
2212 transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2213  transform::TransformResults &results,
2214  transform::TransformState &state) {
2215  SmallVector<Operation *> paddedOps, padOps;
2216 
2217  for (Operation *target : state.getPayloadOps(getTarget())) {
2218  auto targetOp = dyn_cast<TilingInterface>(target);
2219  if (!targetOp) {
2220  auto diag = emitSilenceableError() << "expected TilingInterface target";
2221  diag.attachNote(target->getLoc()) << "target op";
2222  return diag;
2223  }
2224 
2225  // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2226  // loopsToOperand map / C++ APIs to compute the effect of padding on
2227  // operands.
2228  if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2229  auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2230  "are supported at the moment";
2231  diag.attachNote(target->getLoc()) << "target op";
2232  return diag;
2233  }
2234 
2235  // Convert the padding values to attributes.
2236  SmallVector<Attribute> paddingValues;
2237  for (auto const &[untypedAttr, elementOrTensorType] :
2238  llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2239  auto attr = dyn_cast<TypedAttr>(untypedAttr);
2240  Type elementType = getElementTypeOrSelf(elementOrTensorType);
2241 
2242  if (isa<ub::PoisonAttr>(untypedAttr)) {
2243  paddingValues.push_back(untypedAttr);
2244  continue;
2245  }
2246  if (!attr) {
2247  emitOpError("expects padding values to be typed attributes or poison");
2249  }
2250  // Try to parse string attributes to obtain an attribute of element type.
2251  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2252  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2253  stringAttr, getContext(), elementType,
2254  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2255  if (!parsedAttr || parsedAttr.getType() != elementType) {
2256  auto diag = this->emitOpError("expects a padding that parses to ")
2257  << elementType << ", got " << attr;
2258  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2260  }
2261  paddingValues.push_back(parsedAttr);
2262  continue;
2263  }
2264  // Otherwise, add the attribute directly.
2265  if (attr.getType() != elementType) {
2266  auto diag = this->emitOpError("expects a padding value of type ")
2267  << elementType << ", got " << attr;
2268  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2270  }
2271  paddingValues.push_back(attr);
2272  }
2273 
2274  // Set options.
2275  TilingInterface paddedOp;
2277  options.setPaddingValues(paddingValues)
2278  .setPaddingSizes(getMixedPaddingSizes())
2279  .setPadToMultipleOf(getPadToMultipleOf());
2280 
2281  // Apply padding.
2282  SmallVector<tensor::PadOp> newPadOps;
2283  FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
2284  rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
2285  newPadOps);
2286  if (failed(maybePaddedOp)) {
2287  auto diag = emitSilenceableError() << "failed to pad op";
2288  diag.attachNote(target->getLoc()) << "target op";
2289  return diag;
2290  }
2291 
2292  // Set transform results.
2293  paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2294  padOps.append(newPadOps.begin(), newPadOps.end());
2295  }
2296 
2297  results.set(cast<OpResult>(getPadded()), paddedOps);
2298  results.set(cast<OpResult>(getPad()), padOps);
2300 }
2301 
2302 LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2303 
2304 //===---------------------------------------------------------------------===//
2305 // HoistPadOp
2306 //===---------------------------------------------------------------------===//
2307 
2308 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2309  transform::TransformRewriter &rewriter,
2310  transform::TransformResults &transformResults,
2311  transform::TransformState &state) {
2312  auto targetOps = state.getPayloadOps(getTarget());
2313  auto loopOps = state.getPayloadOps(getLoop());
2314  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2315  return emitDefiniteFailure()
2316  << "requires exactly one target and one loop handle (got "
2317  << llvm::range_size(targetOps) << " and "
2318  << llvm::range_size(loopOps) << ")";
2319  }
2320 
2321  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2322  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2323  if (!padOp || !loopOp)
2324  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2325 
2326  FailureOr<linalg::detail::PackingResult> result =
2327  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2328  getTranspose());
2329  if (failed(result))
2330  return emitDefiniteFailure() << "could not build packing loop nest";
2331 
2332  if (result->clonedLoopIvs.empty()) {
2333  transformResults.set(cast<OpResult>(getPackingLoop()),
2334  {result->hoistedPadOp.getOperation()});
2336  }
2337  auto outerPackedLoop =
2338  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2339  transformResults.set(cast<OpResult>(getPackingLoop()),
2340  {outerPackedLoop.getOperation()});
2342 }
2343 
2345  ArrayRef<int64_t> transpose = getTranspose();
2346  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2347  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2348  transpose.end())) {
2349  return emitOpError() << "expects transpose to be a permutation, found "
2350  << getTranspose();
2351  }
2352  return success();
2353 }
2354 
2355 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2357  transform::onlyReadsHandle(getTargetMutable(), effects);
2358  transform::onlyReadsHandle(getLoopMutable(), effects);
2359  transform::producesHandle(getOperation()->getOpResults(), effects);
2360  transform::modifiesPayload(effects);
2361 }
2362 
2364 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2365  tensor::PadOp target,
2367  transform::TransformState &state) {
2368  tensor::PadOp hoistedPadOp;
2369  SmallVector<TransposeOp> transposeOps;
2370  FailureOr<Value> result =
2371  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2372  hoistedPadOp, transposeOps);
2373  if (succeeded(result)) {
2374  // We need to perform our own replacement here because this API is still
2375  // used in patterns that "pad and hoist", for which the replacement values
2376  // need to be different.
2377  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2378  // that we have more composable abstractions.
2379  rewriter.replaceOp(target, *result);
2380  results.push_back(hoistedPadOp);
2382  }
2383  return emitDefaultSilenceableFailure(target);
2384 }
2385 
2386 LogicalResult transform::HoistPadOp::verify() {
2387  ArrayRef<int64_t> transpose = getTranspose();
2388  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2389  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2390  transpose.end())) {
2391  return emitOpError() << "expects transpose to be a permutation, found "
2392  << getTranspose();
2393  }
2394  return success();
2395 }
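
// A minimal sketch of hoisting a pad op out of its enclosing loops with the
// same utility used above, assuming the caller provides the rewriter and the
// pad; hoisting through 2 loops with no transposition is illustrative only.
static FailureOr<Value> hoistPadExample(RewriterBase &rewriter,
                                        tensor::PadOp padOp) {
  tensor::PadOp hoistedPadOp;
  SmallVector<TransposeOp> transposeOps;
  return hoistPaddingOnTensors(rewriter, padOp, /*numLoops=*/2,
                               /*transposeVector=*/{}, hoistedPadOp,
                               transposeOps);
}
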
2396 
2397 //===----------------------------------------------------------------------===//
2398 // PromoteOp
2399 //===----------------------------------------------------------------------===//
2400 
2402 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2403  LinalgOp target,
2405  transform::TransformState &state) {
2406  LinalgPromotionOptions promotionOptions;
2407  if (!getOperandsToPromote().empty())
2408  promotionOptions = promotionOptions.setOperandsToPromote(
2409  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2410  if (getUseFullTilesByDefault())
2411  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2412  getUseFullTilesByDefault());
2413  if (getUseOriginalSubviewSize())
2414  promotionOptions =
2415  promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
2416  if (getUseAlloca())
2417  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2418  if (!getUseFullTileBuffers().empty())
2419  promotionOptions = promotionOptions.setUseFullTileBuffers(
2420  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2421  if (getAlignment().has_value())
2422  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2423  if (getMemorySpace().has_value())
2424  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2425 
2426  if (getMapping().has_value()) {
2427  // The mapping should contain only one element.
2428  auto mapping = *getMapping();
2429  if (mapping.size() > 1)
2430  return emitDefaultDefiniteFailure(target);
2431 
2432  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2433 
2434  if (addressSpace.getAddressSpace() ==
2435  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2436  promotionOptions =
2437  promotionOptions
2441  .setUseFullTileBuffers({false, false});
2442  } else if (addressSpace.getAddressSpace() ==
2443  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2444  promotionOptions =
2445  promotionOptions
2449  .setUseFullTileBuffers({false, false});
2450  } else {
2451  return emitDefaultDefiniteFailure(target);
2452  }
2453  }
2454 
2455  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2456  return emitDefaultDefiniteFailure(target);
2457 
2458  rewriter.setInsertionPoint(target);
2459  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2460  if (failed(res))
2461  return emitDefaultDefiniteFailure(target);
2462  results.push_back(target);
2464 }
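
// A hedged sketch of the promotion flow above, assuming the caller provides
// the rewriter and the LinalgOp; promoting operands 0 and 1 with a 128-byte
// alignment is illustrative, not a default of the transform.
static FailureOr<LinalgOp> promoteExample(RewriterBase &rewriter,
                                          LinalgOp linalgOp) {
  LinalgPromotionOptions options;
  options = options.setOperandsToPromote({0, 1}).setAlignment(128);
  if (failed(promoteSubviewsPrecondition(linalgOp, options)))
    return failure();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  return promoteSubViews(rewriter, linalgOp, options);
}
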
2465 
2466 //===----------------------------------------------------------------------===//
2467 // ReplaceOp
2468 //===----------------------------------------------------------------------===//
2469 
2471 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2472  TransformResults &transformResults,
2473  TransformState &state) {
2474  auto payload = state.getPayloadOps(getTarget());
2475 
2476  // Check for invalid targets.
2477  for (Operation *target : payload) {
2478  if (target->getNumOperands() > 0)
2479  return emitDefiniteFailure() << "expected target without operands";
2480  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2481  target->getNumRegions() > 0)
2482  return emitDefiniteFailure()
2483  << "expected target that is isolated from above";
2484  }
2485 
2486  // Clone and replace.
2487  Operation *pattern = &getBodyRegion().front().front();
2488  SmallVector<Operation *> replacements;
2489  for (Operation *target : payload) {
2490  if (getOperation()->isAncestor(target))
2491  continue;
2492  rewriter.setInsertionPoint(target);
2493  Operation *replacement = rewriter.clone(*pattern);
2494  rewriter.replaceOp(target, replacement->getResults());
2495  replacements.push_back(replacement);
2496  }
2497  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2499 }
2500 
2501 void transform::ReplaceOp::getEffects(
2503  consumesHandle(getTargetMutable(), effects);
2504  producesHandle(getOperation()->getOpResults(), effects);
2505  modifiesPayload(effects);
2506 }
2507 
2508 LogicalResult transform::ReplaceOp::verify() {
2509  if (!getBodyRegion().hasOneBlock())
2510  return emitOpError() << "expected one block";
2511  if (std::distance(getBodyRegion().front().begin(),
2512  getBodyRegion().front().end()) != 1)
2513  return emitOpError() << "expected one operation in block";
2514  Operation *replacement = &getBodyRegion().front().front();
2515  if (replacement->getNumOperands() > 0)
2516  return replacement->emitOpError()
2517  << "expected replacement without operands";
2518  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2519  replacement->getNumRegions() > 0)
2520  return replacement->emitOpError()
2521  << "expect op that is isolated from above";
2522  return success();
2523 }
2524 
2525 //===----------------------------------------------------------------------===//
2526 // ScalarizeOp
2527 //===----------------------------------------------------------------------===//
2528 
2530 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2531  LinalgOp target,
2533  transform::TransformState &state) {
2534  scf::SCFTilingOptions tilingOptions;
2535  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2536  SmallVector<OpFoldResult> tileSizes;
2537  Location loc = target.getLoc();
2538  SmallVector<OpFoldResult> allShapeSizes =
2539  target.createFlatListOfOperandDims(b, loc);
2540  AffineMap map = target.getShapesToLoopsMap();
2541  if (!map)
2542  return tileSizes;
2543  SmallVector<OpFoldResult> shapeSizes =
2545  allShapeSizes);
2546  // If the shape size is dynamic, tile by 1.
2547  // Otherwise, do not tile (i.e. tile size 0).
2548  for (OpFoldResult shapeSize : shapeSizes) {
2549  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2550  : b.getIndexAttr(1));
2551  }
2552  return tileSizes;
2553  });
2554  rewriter.setInsertionPoint(target);
2555  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2556  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2557  if (failed(maybeTilingResult))
2558  return emitDefaultDefiniteFailure(target);
2559 
2560  if (target->getNumResults())
2561  rewriter.replaceOp(target, maybeTilingResult->replacements);
2562  else
2563  rewriter.eraseOp(target);
2564 
2565  results.reserve(maybeTilingResult->tiledOps.size());
2566  for (Operation *tiled : maybeTilingResult->tiledOps)
2567  results.push_back(tiled);
2569 }
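
// A minimal sketch of non-scalarizing tiling with the same utility, assuming
// a caller-provided rewriter and TilingInterface op; tiling the two leading
// loops by 4 and 8 is illustrative (a tile size of 0 means "do not tile").
static FailureOr<scf::SCFTilingResult> tileExample(RewriterBase &rewriter,
                                                   TilingInterface op) {
  scf::SCFTilingOptions tilingOptions;
  SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(4),
                                         rewriter.getIndexAttr(8)};
  tilingOptions.setTileSizes(tileSizes);
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  return scf::tileUsingSCF(rewriter, op, tilingOptions);
}
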
2570 
2571 //===----------------------------------------------------------------------===//
2572 // ConvertToLoopsOp
2573 //===----------------------------------------------------------------------===//
2574 
2576 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2577  transform::TransformResults &results,
2578  transform::TransformState &state) {
2580  for (Operation *target : state.getPayloadOps(getTarget())) {
2581  auto tilingOp = dyn_cast<TilingInterface>(*target);
2582  if (!tilingOp) {
2584  emitSilenceableError()
2585  << "expected the payload to implement TilingInterface";
2586  diag.attachNote(target->getLoc()) << "payload op";
2587  return diag;
2588  }
2589  rewriter.setInsertionPoint(target);
2590  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2591  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2592  if (failed(generatedLoops))
2593  return emitDefaultDefiniteFailure(target);
2594  for (scf::ForOp &loop : *generatedLoops) {
2595  loops.push_back(loop.getOperation());
2596  }
2597  rewriter.eraseOp(target);
2598  }
2599  results.set(cast<OpResult>(getResult()), loops);
2601 }
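
// A minimal sketch of the loop lowering performed above on a single op,
// assuming the caller provides the rewriter and a TilingInterface payload.
static LogicalResult convertToLoopsExample(RewriterBase &rewriter,
                                           TilingInterface tilingOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(tilingOp);
  FailureOr<SmallVector<scf::ForOp>> loops =
      scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
  if (failed(loops))
    return failure();
  // The op is fully expanded into loops and can be erased.
  rewriter.eraseOp(tilingOp);
  return success();
}
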
2602 
2603 //===----------------------------------------------------------------------===//
2604 // RewriteInDestinationPassingStyleOp
2605 //===----------------------------------------------------------------------===//
2606 
2608 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2609  transform::TransformRewriter &rewriter, Operation *target,
2611  transform::TransformState &state) {
2612  rewriter.setInsertionPoint(target);
2613  FailureOr<Operation *> maybeResult =
2615  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2616  [&rewriter](auto op) {
2617  return rewriteInDestinationPassingStyle(rewriter, op);
2618  });
2619  if (failed(maybeResult))
2620  return emitDefaultSilenceableFailure(target);
2621  results.push_back(*maybeResult);
2623 }
2624 
2625 //===----------------------------------------------------------------------===//
2626 // SplitOp
2627 //===----------------------------------------------------------------------===//
2628 
2630 SplitOp::apply(transform::TransformRewriter &rewriter,
2631  TransformResults &results, TransformState &state) {
2632  // Collect the dynamic split points if provided.
2633  SmallVector<Operation *> payload =
2634  llvm::to_vector(state.getPayloadOps(getTarget()));
2635 
2636  bool isMultiwaySplit = getMultiway();
2637 
2638  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2639  return mlir::emitSilenceableFailure(getLoc())
2640  << "requires exactly one target when "
2641  "multiway split is enabled (got "
2642  << llvm::range_size(payload) << ")";
2643  }
2644 
2645  SmallVector<OpFoldResult> chunkSizes;
2646 
2647  if (!isMultiwaySplit)
2648  chunkSizes.reserve(payload.size());
2649 
2650  if (getDynamicChunkSizes()) {
2652  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2653  chunkSizes = llvm::to_vector(llvm::map_range(
2654  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2655  if (op->getNumResults() != 1 ||
2656  !op->getResult(0).getType().isIndex()) {
2657  diag = emitSilenceableError()
2658  << "expected dynamic split point handle to point to a "
2659  "single-result index-typed op";
2660  diag.attachNote(op->getLoc()) << "dynamic split point";
2661  }
2662  return OpFoldResult(op->getResult(0));
2663  }));
2664  } else {
2665  chunkSizes = llvm::to_vector(
2666  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2667  [](Attribute attr) { return OpFoldResult(attr); }));
2668  }
2669  if (diag.isSilenceableFailure())
2670  return diag;
2671 
2672  // For multiway split, a single payload is expected to have multiple
2673  // split points.
2674  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2675  return emitDefiniteFailure()
2676  << "expected the dynamic split point handle to point to as "
2677  "many operations ("
2678  << chunkSizes.size() << ") as the target handle ("
2679  << payload.size() << ")";
2680  }
2681  } else {
2682  chunkSizes.resize(payload.size(),
2683  rewriter.getIndexAttr(getStaticChunkSizes()));
2684  }
2685 
2686  auto checkStructuredOpAndDimensions =
2687  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2688  if (!linalgOp) {
2689  auto diag = emitSilenceableError() << "only applies to structured ops";
2690  diag.attachNote(loc) << "target op";
2691  return diag;
2692  }
2693 
2694  if (getDimension() >= linalgOp.getNumLoops()) {
2695  auto diag = emitSilenceableError() << "dimension " << getDimension()
2696  << " does not exist in target op";
2697  diag.attachNote(loc) << "target op";
2698  return diag;
2699  }
2701  };
2702 
2703  auto checkFailureInSplitting =
2704  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2705  if (hasFailed) {
2706  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2707  diag.attachNote(loc) << "target op";
2708  return diag;
2709  }
2711  };
2712 
2713  SmallVector<Operation *> opList;
2714  if (isMultiwaySplit) {
2715 
2716  // Split a single target operation at multiple points.
2717  TilingInterface head, tail;
2718  Operation *target = payload.front();
2719 
2720  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2721 
2722  // Check that the target is a valid LinalgOp with correct dimensions.
2724  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2725  if (diag.isSilenceableFailure())
2726  return diag;
2727 
2728  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2729 
2730  if (idx > 0)
2731  target = tail.getOperation();
2732 
2733  if (!target)
2734  break;
2735 
2736  linalgOp = cast<LinalgOp>(target);
2737  Location loc = target->getLoc();
2738 
2739  rewriter.setInsertionPoint(linalgOp);
2740  std::tie(head, tail) = linalg::splitOp(
2741  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2742  getDimension(), chunkSize);
2743 
2744  // Propagate errors.
2746  checkFailureInSplitting(!head && !tail, loc);
2747  if (diag.isDefiniteFailure())
2748  return diag;
2749 
2750  opList.push_back(head.getOperation());
2751  }
2752 
2753  // Append any leftover parts to the end of the result list.
2754  if (tail)
2755  opList.push_back(tail.getOperation());
2756 
2757  } else {
2758  // Split each target operation.
2759  SmallVector<Operation *> first, second;
2760  Operation *noSecondPart = nullptr;
2761  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2762  Operation *target = std::get<0>(pair);
2763  Location loc = target->getLoc();
2764  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2766  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2767 
2768  if (diag.isSilenceableFailure())
2769  return diag;
2770 
2771  rewriter.setInsertionPoint(linalgOp);
2772  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2773  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2774  getDimension(), std::get<1>(pair));
2775 
2776  // Propagate errors.
2777  DiagnosedSilenceableFailure diagSplit =
2778  checkFailureInSplitting(!first.back() && !second.back(), loc);
2779  if (diagSplit.isDefiniteFailure())
2780  return diagSplit;
2781 
2782  // Do not add null second parts.
2783  if (!second.back()) {
2784  noSecondPart = target;
2785  second.pop_back();
2786  }
2787  }
2788 
2789  if (second.size() != first.size() && !second.empty()) {
2790  auto diag = emitSilenceableError()
2791  << "splitting does not produce the second part for a subset "
2792  "of targets";
2793  diag.attachNote()
2794  << "expected splitting to produce the second part of all "
2795  "or none of the targets";
2796  diag.attachNote(noSecondPart->getLoc())
2797  << "first target with no second part";
2798  return diag;
2799  }
2800 
2801  opList.append(first);
2802  if (second.size())
2803  opList.append(second);
2804  }
2805  results.set(cast<OpResult>(getSplitList()), opList);
2807 }
2808 
2809 void SplitOp::getEffects(
2811  consumesHandle(getTargetMutable(), effects);
2812  if (getDynamicChunkSizes())
2813  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2814  producesHandle(getOperation()->getOpResults(), effects);
2815  modifiesPayload(effects);
2816 }
2817 
2818 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2819  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2820  IntegerAttr staticChunkSizes;
2821  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2822  return failure();
2823 
2824  OptionalParseResult dynamicPointParseResult =
2825  parser.parseOptionalOperand(dynamicChunkSizes);
2826  if (!dynamicPointParseResult.has_value()) {
2827  int64_t staticChunkSizesValue;
2828  if (failed(parser.parseInteger(staticChunkSizesValue)))
2829  return failure();
2830 
2831  staticChunkSizes =
2832  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2833  }
2834 
2835  Type targetType;
2836  if (parser.parseOptionalAttrDict(result.attributes) ||
2837  parser.parseColonType(targetType) ||
2838  parser.resolveOperand(target, targetType, result.operands)) {
2839  return failure();
2840  }
2841  if (dynamicPointParseResult.has_value()) {
2842  Type ChunkSizesType;
2843  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2844  parser.parseType(ChunkSizesType) ||
2845  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2846  result.operands)) {
2847  return failure();
2848  }
2849 
2850  staticChunkSizes =
2851  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2852  }
2853 
2854  result.addAttribute(
2855  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2856  staticChunkSizes);
2857  result.addTypes(targetType);
2858  return success();
2859 }
2860 
2861 void SplitOp::print(OpAsmPrinter &printer) {
2862  printer << " " << getTarget() << " after ";
2863  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2864  if (staticChunkSize != ShapedType::kDynamic)
2865  printer << staticChunkSize;
2866  else
2867  printer << getDynamicChunkSizes();
2868  printer << " ";
2869  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2870  {getStaticChunkSizesAttrName()});
2871  printer << " : " << getTarget().getType();
2872  if (staticChunkSize == ShapedType::kDynamic)
2873  printer << ", " << getDynamicChunkSizes().getType();
2874 }
2875 
2876 LogicalResult SplitOp::verify() {
2877  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2878  (getDynamicChunkSizes() == nullptr)) {
2879  return emitOpError() << "expects either a dynamic or a static split "
2880  "point to be provided";
2881  }
2882  return success();
2883 }
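
// A minimal sketch of splitting a single op with the same utility, assuming a
// caller-provided rewriter, TilingInterface op and dimension; the static chunk
// size of 16 is illustrative. The second part may be a null op when nothing
// remains after the split point.
static std::pair<TilingInterface, TilingInterface>
splitExample(RewriterBase &rewriter, TilingInterface op, unsigned dimension) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  TilingInterface head, tail;
  std::tie(head, tail) =
      linalg::splitOp(rewriter, op, dimension, rewriter.getIndexAttr(16));
  return {head, tail};
}
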
2884 
2885 //===----------------------------------------------------------------------===//
2886 // SplitReductionOp
2887 //===----------------------------------------------------------------------===//
2888 
2889 void transform::SplitReductionOp::build(
2890  OpBuilder &builder, OperationState &result, Value target,
2891  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2892  bool useScalingAlgorithm, bool useAlloc) {
2893  MLIRContext *ctx = builder.getContext();
2894  result.addOperands(target);
2895  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2896  builder.getI64IntegerAttr(splitFactor));
2897  result.addAttribute(
2898  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2899  builder.getI64IntegerAttr(insertSplitDimension));
2900  if (innerParallel) {
2901  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2902  builder.getUnitAttr());
2903  }
2904  if (useScalingAlgorithm) {
2905  result.addAttribute(
2906  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2907  builder.getUnitAttr());
2908  }
2909  if (useAlloc) {
2910  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2911  builder.getUnitAttr());
2912  }
2913  auto resultType = transform::AnyOpType::get(ctx);
2914  result.addTypes({resultType, resultType, resultType, resultType});
2915 }
2916 
2917 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2918  transform::TransformRewriter &rewriter, LinalgOp target,
2920  transform::TransformState &state) {
2921  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2922  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2923  unsigned(getInsertSplitDimension()),
2924  bool(getInnerParallel())};
2925  };
2926  rewriter.setInsertionPoint(target);
2927  FailureOr<SplitReductionResult> splitResult =
2928  (getUseScalingAlgorithm())
2929  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2930  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2931  if (failed(splitResult))
2932  return emitDefaultDefiniteFailure(target);
2933 
2934  results.push_back(splitResult->initOrAlloc);
2935  results.push_back(splitResult->fillOp);
2936  results.push_back(splitResult->splitLinalgOp);
2937  results.push_back(splitResult->resultCombiningLinalgOp);
2939 }
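
// A hedged sketch of the split-reduction rewrite invoked above, assuming the
// caller provides the rewriter and a LinalgOp with a reduction; the split
// factor of 4 and insertion position 0 are illustrative values.
static FailureOr<SplitReductionResult>
splitReductionExample(RewriterBase &rewriter, LinalgOp linalgOp) {
  ControlSplitReductionFn splitFn = [](LinalgOp) {
    return linalg::SplitReductionOptions{/*splitFactor=*/4,
                                         /*insertSplitDimension=*/0,
                                         /*innerParallel=*/false};
  };
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(linalgOp);
  return splitReduction(rewriter, linalgOp, splitFn, /*useAlloc=*/false);
}
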
2940 
2941 //===----------------------------------------------------------------------===//
2942 // TileReductionUsingForOp
2943 //===----------------------------------------------------------------------===//
2944 
2945 void transform::TileReductionUsingForOp::build(
2946  OpBuilder &builder, OperationState &result, Value target,
2947  ArrayRef<int64_t> staticTileSizes) {
2948  // Call the default builder.
2949  // This is future-proof with respect to mixed static-dynamic sizes and sets
2950  // up the proper operand segment size attributes for multiple variadic
2951  // operands. In the absence of this, horrible bugs ensue.
2952  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2953  MLIRContext *ctx = builder.getContext();
2954  auto opTy = transform::AnyOpType::get(ctx);
2955  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
2956  build(builder, result,
2957  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2958  /*target=*/target,
2959  /*reduction_dims=*/nullptr,
2960  /*tile_sizes=*/staticTileSizesAttr);
2961 }
2962 
2963 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2964  transform::TransformRewriter &rewriter, Operation *target,
2966  transform::TransformState &state) {
2967  rewriter.setInsertionPoint(target);
2968 
2969  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2970  if (!partialReductionOp) {
2971  return emitSilenceableFailure(
2972  target->getLoc(),
2973  "Operation should implement PartialReductionOpInterface");
2974  }
2975 
2976  SmallVector<unsigned> reductionDims =
2977  extractFromIntegerArrayAttr<unsigned>(getReductionDims());
2978  if (reductionDims.empty()) {
2979  for (auto [idx, iteratorType] :
2980  llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
2981  if (iteratorType == utils::IteratorType::reduction)
2982  reductionDims.push_back(idx);
2983  }
2984  }
2985 
2988  options.setReductionTilingStrategy(
2990  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
2991  options.setReductionDims(reductionDims);
2992  FailureOr<scf::SCFTilingResult> result =
2993  scf::tileUsingSCF(rewriter, partialReductionOp, options);
2994 
2995  if (failed(result)) {
2996  return emitSilenceableFailure(getLoc(),
2997  "failed to tile using partial reduction");
2998  }
2999  rewriter.replaceOp(target, result->replacements);
3000  for (Value initValue : result->initialValues)
3001  results.push_back(initValue.getDefiningOp());
3002  for (auto parallelTiledOp : result->tiledOps)
3003  results.push_back(parallelTiledOp);
3004  for (auto mergeOp : result->mergeOps)
3005  results.push_back(mergeOp);
3006  results.push_back(result->loops.front());
3008 }
3009 
3010 //===----------------------------------------------------------------------===//
3011 // TileReductionUsingForallOp
3012 //===----------------------------------------------------------------------===//
3013 
3014 void transform::TileReductionUsingForallOp::build(
3015  OpBuilder &builder, OperationState &result, Value target,
3016  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3017  ArrayAttr mapping) {
3018  // Call the default builder.
3019  // This is future-proof re mixed static-dynamic and setting up the proper
3020  // operands segment sizes attributes for multiple variadic operands.
3021  // In the absence of this, horrible bugs ensue.
3022  // TODO: support mixed static-dynamic (see TileUsingForallOp).
3023  MLIRContext *ctx = builder.getContext();
3024  auto opTy = transform::AnyOpType::get(ctx);
3025  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3026  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3027  build(builder, result,
3028  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3029  /*target=*/target,
3030  /*reduction_dims=*/{},
3031  /*num_threads=*/staticNumThreadsAttr,
3032  /*tile_sizes=*/staticTileSizesAttr,
3033  /*mapping=*/mapping);
3034 }
3035 
3036 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3037  transform::TransformRewriter &rewriter, Operation *target,
3038     transform::ApplyToEachResultList &results,
3039     transform::TransformState &state) {
3040  rewriter.setInsertionPoint(target);
3041 
3042  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3043  if (!partialReductionOp) {
3044  return emitSilenceableFailure(
3045  target->getLoc(),
3046  "Operation should implement PartialReductionOpInterface");
3047  }
3048  SmallVector<OpFoldResult> numThreads =
3049  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
3050   SmallVector<OpFoldResult> tileSizes =
3051       getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
3052 
3053   scf::SCFTilingOptions options;
3054   options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3055   options.setReductionTilingStrategy(
3056       ReductionTilingStrategy::PartialReductionOuterParallel);
3057  if (!getNumThreads().empty()) {
3058  options.setNumThreads(numThreads);
3059  } else {
3060  options.setTileSizes(tileSizes);
3061  }
3062  if (auto mapping = getMapping()) {
3063  options.setMapping(mapping.value().getValue());
3064  }
3065  SmallVector<unsigned> reductionDims =
3066  extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3067  if (reductionDims.empty()) {
3068  for (auto [idx, iteratorType] :
3069  llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3070  if (iteratorType == utils::IteratorType::reduction)
3071  reductionDims.push_back(idx);
3072  }
3073  }
3074  options.setReductionDims(reductionDims);
3075  FailureOr<scf::SCFTilingResult> result =
3076  scf::tileUsingSCF(rewriter, partialReductionOp, options);
3077 
3078  if (failed(result)) {
3079  auto diag = emitSilenceableError() << "could not tile reduction";
3080  return diag;
3081  }
3082  rewriter.replaceOp(target, result->replacements);
3083 
3084  for (Value initValue : result->initialValues)
3085  results.push_back(initValue.getDefiningOp());
3086  for (auto parallelTiledOp : result->tiledOps)
3087  results.push_back(parallelTiledOp);
3088  for (auto mergeOp : result->mergeOps)
3089  results.push_back(mergeOp);
3090  results.push_back(result->loops.front());
3091   return DiagnosedSilenceableFailure::success();
3092 }
3093 
3094 //===----------------------------------------------------------------------===//
3095 // ContinuousTileSizesOp
3096 //===----------------------------------------------------------------------===//
3097 
3098 DiagnosedSilenceableFailure
3099 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3100  TransformResults &transformResults,
3101  TransformState &state) {
3102 
3103  SmallVector<Operation *> targetOps =
3104  llvm::to_vector(state.getPayloadOps(getTarget()));
3105 
3106  if (!llvm::hasSingleElement(targetOps)) {
3107  return mlir::emitSilenceableFailure(getLoc())
3108  << "requires exactly one target (got " << llvm::range_size(targetOps)
3109  << ")";
3110  }
3111 
3112  Operation *target = *targetOps.begin();
3113  auto linalgOp = dyn_cast<LinalgOp>(target);
3114  auto tileableOp = dyn_cast<TilingInterface>(target);
3115 
3116  if (!linalgOp)
3117  return emitDefiniteFailure() << "expected Linalg Op";
3118 
3119  OpBuilder builder(linalgOp.getContext());
3120 
3121  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
3122  if (linalgOp.hasDynamicShape()) {
3123  auto diag = emitSilenceableError()
3124  << "cannot compute parametric tile sizes for dynamically "
3125  "shaped payload op";
3126  diag.attachNote(linalgOp->getLoc()) << "payload op";
3127  return diag;
3128  }
3129 
3130  FailureOr<StaticContinuousTileSizeSpecification> spec =
3131  computeStaticContinuousTileSizes(linalgOp, getDimension(),
3132  getTargetSize());
3133  if (failed(spec)) {
3134  return emitSilenceableError()
3135  << "failed to compute multi-size tiling sizes";
3136  }
3137 
3138  SmallVector<int64_t> chunkSizes;
3139 
3140  for (auto &&[tileSize, tripCount] :
3141  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3142  chunkSizes.push_back(tileSize * tripCount);
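   // Illustrative example (numbers chosen for exposition): for tileSizes =
   // {5, 2} and tripCounts = {3, 2}, the chunk sizes are {5 * 3, 2 * 2} =
   // {15, 4}, i.e. the extent of the dimension covered by each tile size.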
3143 
3144  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3145  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3146  return builder.getI64IntegerAttr(value);
3147  });
3148  };
3149  transformResults.setParams(cast<OpResult>(getTileSizes()),
3150  getI64AttrsFromI64(spec->tileSizes));
3151  transformResults.setParams(cast<OpResult>(getChunkSizes()),
3152  getI64AttrsFromI64(chunkSizes));
3153 
3154     return DiagnosedSilenceableFailure::success();
3155   }
3156 
3157  builder.setInsertionPoint(linalgOp);
3158 
3159  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
3160  unsigned dimension = getDimension();
3161 
3162  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
3163  builder, tileableOp, dimension, targetSize, true);
3164  if (failed(spec)) {
3165  return emitSilenceableError() << "could not generate tile size computation";
3166  }
3167 
3168  AffineExpr s0 = builder.getAffineSymbolExpr(0);
3169  AffineExpr s1 = builder.getAffineSymbolExpr(1);
3170  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3171  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
3172  ofrs);
3173  };
3174 
3175  SmallVector<Value> chunkSizes;
3176  Value splitPoint;
3177  for (auto &&[tileSize, tripCount] :
3178  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3179  splitPoint = apply(s0 * s1, {tileSize, tripCount});
3180  chunkSizes.push_back(splitPoint);
3181  }
3182 
3183  auto getDefiningOps = [&](ArrayRef<Value> values) {
3184  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3185  return value.getDefiningOp();
3186  });
3187  };
3188 
3189  transformResults.set(cast<OpResult>(getTileSizes()),
3190  getDefiningOps(spec->tileSizes));
3191  transformResults.set(cast<OpResult>(getChunkSizes()),
3192  getDefiningOps(chunkSizes));
3193 
3194   return DiagnosedSilenceableFailure::success();
3195 }
3196 
3197 LogicalResult transform::ContinuousTileSizesOp::verify() {
3198 
3199  if (getTileSizes().getType() != getChunkSizes().getType()) {
3200  return emitOpError() << "expects all results type to be the same";
3201  }
3202 
3203  return success();
3204 }
3205 
3206 void transform::ContinuousTileSizesOp::getEffects(
3207     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3208   if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3209  onlyReadsPayload(effects);
3210  else
3211  modifiesPayload(effects);
3212  onlyReadsHandle(getTargetMutable(), effects);
3213  producesHandle(getOperation()->getOpResults(), effects);
3214 }
3215 
3216 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
3217                                          Type targetType, Type tile_sizes,
3218  Type) {
3219  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
3220 }
3221 
3222 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
3223  Type &targetType,
3224  Type &tileSizesType,
3225  Type &chunkSizesType) {
3226  FunctionType funcType;
3227  llvm::SMLoc typeLoc = parser.getCurrentLocation();
3228  if (failed(parser.parseType<FunctionType>(funcType)))
3229  return failure();
3230 
3231  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3232  parser.emitError(typeLoc) << "expects a trailing functional type with one "
3233  "argument and one result";
3234  }
3235  targetType = funcType.getInput(0);
3236  tileSizesType = chunkSizesType = funcType.getResult(0);
3237 
3238  return success();
3239 }
3240 
3241 //===----------------------------------------------------------------------===//
3242 // TileUsingForOp
3243 //===----------------------------------------------------------------------===//
3244 
3245 void transform::TileUsingForOp::build(
3246  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3247  Value target, ArrayRef<int64_t> staticTileSizes,
3248  ArrayRef<int64_t> interchange,
3249  std::optional<ArrayRef<bool>> scalableSizes) {
3250  return build(builder, result, loopTypes,
3251  /*target=*/target,
3252  /*mixedTileSizes=*/
3253  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3254  interchange, scalableSizes);
3255 }
3256 
3257 void transform::TileUsingForOp::build(
3258  OpBuilder &builder, OperationState &result, Value target,
3259  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3260  std::optional<ArrayRef<bool>> scalableSizes) {
3261  build(builder, result, target,
3262  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3263  interchange, scalableSizes);
3264 }
3265 
3266 void transform::TileUsingForOp::build(
3267  OpBuilder &builder, OperationState &result, Value target,
3268  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3269  std::optional<ArrayRef<bool>> scalableSizes) {
3270   // Loop types are automatically splat by the callee, setting up one is
3271  // enough.
3272  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3273  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3274  scalableSizes);
3275 }
3276 
3277 void transform::TileUsingForOp::build(
3278  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3279  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3280  ArrayRef<int64_t> interchange,
3281  std::optional<ArrayRef<bool>> scalableSizes) {
3282  SmallVector<int64_t> staticTileSizes;
3283  SmallVector<Value> dynamicTileSizes;
3284  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3285  // Call the default builder which sets up the proper operands segment sizes
3286  // attributes for multiple variadic operands. In the absence of this,
3287  // horrible bugs ensue.
3288  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3289  unsigned numExpectedLoops =
3290  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3291  SmallVector<Type> resultTypes;
3292  resultTypes.reserve(numExpectedLoops);
3293  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3294  "expected one loop type or as many as loops");
3295  if (loopTypes.size() == 1)
3296  resultTypes.append(numExpectedLoops, loopTypes[0]);
3297  else
3298  llvm::append_range(resultTypes, loopTypes);
3299  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3300  if (scalableSizes.has_value())
3301  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3302  build(builder, result, /*tiled_linalg_op=*/target.getType(),
3303  /*loops=*/resultTypes,
3304  /*target=*/target,
3305  /*dynamic_sizes=*/dynamicTileSizes,
3306  /*static_sizes=*/staticTileSizesAttr,
3307  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3308  /*scalable_sizes=*/expandedScalableSizes);
3309 }
3310 
3311 LogicalResult transform::TileUsingForOp::verify() {
3312  if (getMixedSizes().size() != getScalableSizes().size())
3313  return emitOpError("expected same number of sizes (")
3314  << getMixedSizes().size() << ") and scalable sizes ("
3315  << getScalableSizes().size() << ")";
3316  ArrayRef<int64_t> staticSizes = getStaticSizes();
3317  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3318  if (getLoops().size() != numExpectedLoops)
3319  return emitOpError("expected number of loops to tile (")
3320  << numExpectedLoops << ") to match number of `loops` results ("
3321  << getLoops().size() << ")";
3322  return success();
3323 }
3324 
3325 DiagnosedSilenceableFailure
3326 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3327  TransformResults &transformResults,
3328  TransformState &state) {
3329  ArrayRef<int64_t> tileSizes = getStaticSizes();
3330 
3331  SmallVector<Operation *> targets =
3332  llvm::to_vector(state.getPayloadOps(getTarget()));
3333  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3334   SmallVector<SmallVector<int64_t>> paramSizes;
3335   dynamicSizeProducers.reserve(getDynamicSizes().size());
3336  paramSizes.reserve(getDynamicSizes().size());
3337  for (Value transformValue : getDynamicSizes()) {
3338  if (isa<ParamType>(transformValue.getType())) {
3339  dynamicSizeProducers.push_back({});
3340  ArrayRef<Attribute> params = state.getParams(transformValue);
3341  paramSizes.push_back(
3342  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3343  return cast<IntegerAttr>(attr).getValue().getSExtValue();
3344  })));
3345 
3346  if (paramSizes.back().size() != targets.size()) {
3347       DiagnosedSilenceableFailure diag =
3348           emitSilenceableError()
3349  << "expected as many parameter values ("
3350  << dynamicSizeProducers.back().size() << ") as target ops ("
3351  << targets.size() << ")";
3352  diag.attachNote(transformValue.getLoc()) << "for this parameter";
3353  return diag;
3354  }
3355 
3356  continue;
3357  }
3358  paramSizes.push_back({});
3359  dynamicSizeProducers.push_back(
3360  llvm::to_vector(state.getPayloadOps(transformValue)));
3361 
3362  if (dynamicSizeProducers.back().size() != targets.size()) {
3363       DiagnosedSilenceableFailure diag =
3364           emitSilenceableError()
3365  << "expected as many dynamic size-producing operations ("
3366  << dynamicSizeProducers.back().size() << ") as target ops ("
3367  << targets.size() << ")";
3368  diag.attachNote(transformValue.getLoc()) << "for this handle";
3369  return diag;
3370  }
3371 
3372  for (Operation *op : dynamicSizeProducers.back()) {
3373  if (op->getNumResults() == 1 &&
3374  isa<IndexType>(op->getResult(0).getType())) {
3375  continue;
3376  }
3377 
3378       DiagnosedSilenceableFailure diag =
3379           emitSilenceableError() << "expected sizes to be produced by ops "
3380  "with a single index-type result";
3381  diag.attachNote(op->getLoc()) << "size producer op";
3382  diag.attachNote(transformValue.getLoc()) << "for this handle";
3383  return diag;
3384  }
3385  }
3386 
3387   SmallVector<Operation *> tiled;
3388   SmallVector<SmallVector<Operation *, 4>, 4> loops;
3389   loops.resize(getLoops().size());
3390  auto scalableSizes = getScalableSizes();
3391  for (auto [i, op] : llvm::enumerate(targets)) {
3392  auto tilingInterface = dyn_cast<TilingInterface>(op);
3393  if (!tilingInterface) {
3394       DiagnosedSilenceableFailure diag =
3395           emitSilenceableError()
3396  << "only ops implementing TilingInterface are supported";
3397  diag.attachNote(op->getLoc()) << "target op";
3398  return diag;
3399  }
3400  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3401       DiagnosedSilenceableFailure diag =
3402           emitSilenceableError()
3403  << "too many tiles provided, expected at most "
3404  << tilingInterface.getLoopIteratorTypes().size() << " found "
3405  << tileSizes.size();
3406  diag.attachNote(op->getLoc()) << "target op";
3407  return diag;
3408  }
3409 
3410  scf::SCFTilingOptions tilingOptions;
3411  if (tileSizes.empty()) {
3412  tilingOptions.setTileSizeComputationFunction(
3413           [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3414             return {};
3415  });
3416  } else {
3417  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3418  Operation *) {
3419         SmallVector<OpFoldResult> sizes;
3420         sizes.reserve(tileSizes.size());
3421  unsigned dynamicIdx = 0;
3422 
3423  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3424  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3425  if (scalableSizes[ofrIdx]) {
3426  auto val = arith::ConstantIndexOp::create(
3427  b, getLoc(), cast<IntegerAttr>(attr).getInt());
3428  Value vscale =
3429  vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
3430  sizes.push_back(
3431  arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3432  } else {
3433  sizes.push_back(attr);
3434  }
3435  continue;
3436  }
3437  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3438  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3439  ++dynamicIdx;
3440  assert((dynamicSizes.empty() ^ params.empty()) &&
3441  "expected either dynamic sizes or parameters");
3442  if (!params.empty()) {
3443  sizes.push_back(b.getIndexAttr(params[index]));
3444  } else {
3445  sizes.push_back(dynamicSizes[index]->getResult(0));
3446  }
3447  }
3448  return sizes;
3449  });
3450  }
3451 
3452  tilingOptions.setInterchange(getInterchange());
3453  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3454  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3455  if (failed(maybeTilingResult))
3456       return DiagnosedSilenceableFailure::definiteFailure();
3457 
3458  rewriter.replaceOp(op, maybeTilingResult->replacements);
3459 
3460  tiled.append(maybeTilingResult->tiledOps);
3461  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3462  loops[en2.index()].push_back(en2.value());
3463  }
3464 
3465  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3466  for (const auto &en : llvm::enumerate(loops))
3467  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3468 
3469   return DiagnosedSilenceableFailure::success();
3470 }
3471 
3472 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3473   ValueRange dynamic = getDynamicSizes();
3474  ArrayRef<int64_t> tileSizes = getStaticSizes();
3475  SmallVector<OpFoldResult> results;
3476  results.reserve(tileSizes.size());
3477  unsigned dynamicPos = 0;
3478  Builder builder(getContext());
3479  for (int64_t size : tileSizes) {
3480  if (size == ShapedType::kDynamic) {
3481  results.push_back(dynamic[dynamicPos++]);
3482  } else {
3483  results.push_back(builder.getIndexAttr(size));
3484  }
3485  }
3486  return results;
3487 }
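// Illustrative example (operand name is hypothetical): with static sizes
// [4, kDynamic, 8] and one dynamic size operand %sz, getMixedSizes() returns
// {index_attr(4), %sz, index_attr(8)}, pairing each kDynamic slot with the
// next dynamic operand in order.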
3488 
3489 void transform::TileUsingForOp::getEffects(
3490     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3491   consumesHandle(getTargetMutable(), effects);
3492  onlyReadsHandle(getDynamicSizesMutable(), effects);
3493  producesHandle(getOperation()->getOpResults(), effects);
3494  modifiesPayload(effects);
3495 }
3496 
3497 //===----------------------------------------------------------------------===//
3498 // TileUsingForallOp
3499 //===----------------------------------------------------------------------===//
3500 
3501 void transform::TileUsingForallOp::build(OpBuilder &builder,
3502  OperationState &result, Value target,
3503  ArrayRef<int64_t> staticTileSizes,
3504                                          transform::TileSizesSpec,
3505                                          ArrayAttr mapping) {
3506  return build(builder, result,
3507  /*target=*/target,
3508  /*mixedTileSizes=*/
3509  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3510  /*_=*/TileSizesSpec(),
3511  /*mapping=*/mapping);
3512 }
3513 
3514 void transform::TileUsingForallOp::build(OpBuilder &builder,
3515  OperationState &result, Value target,
3516  ArrayRef<OpFoldResult> mixedTileSizes,
3517                                          transform::TileSizesSpec,
3518                                          ArrayAttr mapping) {
3519  SmallVector<int64_t> staticTileSizes;
3520  SmallVector<Value> dynamicTileSizes;
3521  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3522  // Call the default builder which sets up the proper operands segment sizes
3523  // attributes for multiple variadic operands. In the absence of this,
3524  // horrible bugs ensue.
3525  MLIRContext *ctx = builder.getContext();
3526  auto operationType = transform::AnyOpType::get(ctx);
3527  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3528  build(builder, result,
3529  /*resultTypes=*/TypeRange{operationType, operationType},
3530  /*target=*/target,
3531  /*num_threads=*/ValueRange{},
3532  /*tile_sizes=*/dynamicTileSizes,
3533  /*packed_num_threads=*/Value(),
3534  /*packed_tile_sizes=*/Value(),
3535  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3536  /*static_tile_sizes=*/staticTileSizesAttr,
3537  /*mapping=*/mapping);
3538 }
3539 
3540 void transform::TileUsingForallOp::build(OpBuilder &builder,
3541  OperationState &result, Value target,
3542  ArrayRef<int64_t> staticNumThreads,
3543                                          transform::NumThreadsSpec,
3544                                          ArrayAttr mapping) {
3545  return build(builder, result, target,
3546  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3547  NumThreadsSpec(), mapping);
3548 }
3549 
3550 void transform::TileUsingForallOp::build(OpBuilder &builder,
3551  OperationState &result, Value target,
3552  ArrayRef<OpFoldResult> mixedNumThreads,
3553                                          transform::NumThreadsSpec,
3554                                          ArrayAttr mapping) {
3555  SmallVector<int64_t> staticNumThreads;
3556  SmallVector<Value> dynamicNumThreads;
3557  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3558  staticNumThreads);
3559  // Call the default builder which sets up the proper operands segment sizes
3560  // attributes for multiple variadic operands. In the absence of this,
3561  // horrible bugs ensue.
3562  MLIRContext *ctx = builder.getContext();
3563  auto operationType = transform::AnyOpType::get(ctx);
3564  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3565  build(builder, result,
3566  /*resultTypes=*/TypeRange{operationType, operationType},
3567  /*target=*/target,
3568  /*num_threads=*/dynamicNumThreads,
3569  /*tile_sizes=*/ValueRange{},
3570  /*packed_num_threads=*/Value(),
3571  /*packed_tile_sizes=*/Value(),
3572  /*static_num_threads=*/staticNumThreadsAttr,
3573  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3574  /*mapping=*/mapping);
3575 }
3576 
3577 /// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3578 /// normalized upper bound.
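/// For example, lb = 4, ub = 20, step = 3 yields the normalized upper bound
/// ceilDiv(20 - 4, 3) = 6, i.e. the loop is rewritten to iterate over [0, 6)
/// with step 1 (numbers chosen for exposition).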
3579 static SmallVector<OpFoldResult>
3580 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3581                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3582                      ArrayRef<OpFoldResult> steps) {
3583  AffineExpr s0, s1, s2;
3584  bindSymbols(rewriter.getContext(), s0, s1, s2);
3585  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3586  SmallVector<OpFoldResult> normalizedUbs;
3587  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3588     OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3589         rewriter, loc, normalizedUbExpr, {lb, ub, step});
3590  normalizedUbs.push_back(normalizedUb);
3591  }
3592  return normalizedUbs;
3593 }
3594 
3595 /// When a loop is normalized, the uses of the induction variable within the
3596 /// loop need to be replaced with `original_lb + old_iv * original_step`.
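/// For example, with original lb = 4 and step = 3, the normalized induction
/// values 0, 1, 2, ... map back to 4, 7, 10, ... via lb + iv * step (numbers
/// chosen for exposition).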
3597 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3598                                             Location loc, ValueRange ivs,
3599                                             ArrayRef<OpFoldResult> lbs,
3600                                             ArrayRef<OpFoldResult> steps) {
3601  AffineExpr s0, s1;
3602  AffineExpr d0;
3603  bindSymbols(rewriter.getContext(), s0, s1);
3604  bindDims(rewriter.getContext(), d0);
3605  AffineExpr denormExpr = s0 + d0 * s1;
3606  SmallVector<Value> denormalizedIvs;
3607 
3608  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3609     OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3610         rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3611  denormalizedIvs.push_back(
3612  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3613  }
3614  return denormalizedIvs;
3615 }
3616 
3617 /// Given a `scf.forall` loop return a loop op with the loop bounds
3618 /// normalized.
3619 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3620 /// At the time of writing, this wasn't done since adding this to the `scf`
3621 /// dialect would disallow the use of `affine.apply` operations due
3622 /// to cyclic dependencies. To avoid churn in lit tests
3623 /// with the change this was added with, defer that to a follow up.
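/// Illustrative effect (syntax abbreviated, numbers chosen for exposition):
/// a loop equivalent to `scf.forall (%i) = (4) to (20) step (3)` becomes
/// `scf.forall (%j) = (0) to (6) step (1)`, with uses of the old induction
/// variable rewritten to `4 + %j * 3` in the merged body.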
3624 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3625  scf::ForallOp loop) {
3626  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3627  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3628  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3629 
3630  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3631  return loop;
3632  }
3633 
3634  Location loc = loop.getLoc();
3635  SmallVector<OpFoldResult> normalizedUbs =
3636  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3637  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3638  rewriter.getIndexAttr(0));
3639  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3640  rewriter.getIndexAttr(1));
3641 
3642  auto normalizedForallOp = scf::ForallOp::create(
3643  rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3644  loop.getOutputs(), loop.getMapping(),
3645  [](OpBuilder &, Location, ValueRange) {});
3646 
3647  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3648  OpBuilder::InsertionGuard g(rewriter);
3649  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3650  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3651 
3652  SmallVector<Value> argValues =
3653  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3654  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3655  normalizedForallOp.getRegionIterArgs().end());
3656  Block *origLoopBlock = loop.getBody();
3657  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3658 
3659  rewriter.replaceOp(loop, normalizedForallOp);
3660  return normalizedForallOp;
3661 }
3662 
3663 static DiagnosedSilenceableFailure tileToForallOpImpl(
3664     RewriterBase &rewriter, transform::TransformState &state,
3665  TransformOpInterface transformOp, Operation *target,
3666  ArrayRef<OpFoldResult> mixedNumThreads,
3667  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3668  scf::SCFTilingResult &tilingResult) {
3669  // Transform all targets one by one.
3670  auto tileableOp = dyn_cast<TilingInterface>(target);
3671  if (!tileableOp) {
3672     DiagnosedSilenceableFailure diag =
3673         transformOp.emitSilenceableError()
3674  << "only TilingInterface ops are supported";
3675  diag.attachNote(target->getLoc()) << "target op";
3676  return diag;
3677  }
3678  rewriter.setInsertionPoint(tileableOp);
3679   scf::SCFTilingOptions options;
3680   options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3681   if (!mixedNumThreads.empty()) {
3682  options.setNumThreads(mixedNumThreads);
3683  } else {
3684  options.setTileSizes(mixedTileSizes);
3685  }
3686  if (mapping) {
3687  options.setMapping(mapping.value().getValue());
3688  }
3689  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3690  scf::tileUsingSCF(rewriter, tileableOp, options);
3691 
3692  if (failed(maybeTilingResult))
3693  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3694 
3695  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3696 
3697  tilingResult = *maybeTilingResult;
3698 
3699  if (mixedNumThreads.empty()) {
3700  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3701  OpBuilder::InsertionGuard g(rewriter);
3702  rewriter.setInsertionPoint(generatedForallOp);
3703  scf::ForallOp normalizedForallOp =
3704  normalizeForallLoopOp(rewriter, generatedForallOp);
3705  tilingResult.loops.front() = normalizedForallOp;
3706  }
3707 
3708   return DiagnosedSilenceableFailure::success();
3709 }
3710 
3711 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3712  transform::TransformRewriter &rewriter,
3713  transform::TransformResults &transformResults,
3714  transform::TransformState &state) {
3715  auto transformOp = cast<TransformOpInterface>(getOperation());
3716 
3717  // Result payload ops.
3718  SmallVector<Operation *> tileOps;
3719  SmallVector<Operation *> tiledOps;
3720 
3721  // Unpack handles.
3722  SmallVector<OpFoldResult> mixedNumThreads;
3723   DiagnosedSilenceableFailure status =
3724       getPackedNumThreads()
3725           ? unpackSingleIndexResultPayloadOperations(
3726                 state, transformOp, mixedNumThreads, getPackedNumThreads())
3727           : unpackSingleIndexResultPayloadOperations(
3728                 state, transformOp, mixedNumThreads, getMixedNumThreads());
3729  if (!status.succeeded())
3730  return status;
3731  SmallVector<OpFoldResult> mixedTileSizes;
3732   status = getPackedTileSizes()
3733                ? unpackSingleIndexResultPayloadOperations(
3734                      state, transformOp, mixedTileSizes, getPackedTileSizes())
3735                : unpackSingleIndexResultPayloadOperations(
3736                      state, transformOp, mixedTileSizes, getMixedTileSizes());
3737  if (!status.succeeded())
3738  return status;
3739 
3740  for (Operation *target : state.getPayloadOps(getTarget())) {
3741  scf::SCFTilingResult tilingResult;
3742     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3743         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3744  getMapping(), tilingResult);
3745  if (!diag.succeeded())
3746  return diag;
3747  tileOps.push_back(tilingResult.loops.front());
3748  tiledOps.append(tilingResult.tiledOps);
3749  }
3750 
3751  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3752  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3753 
3754   return DiagnosedSilenceableFailure::success();
3755 }
3756 
3757 void transform::TileUsingForallOp::getEffects(
3758     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3759   consumesHandle(getTargetMutable(), effects);
3760  onlyReadsHandle(getTileSizesMutable(), effects);
3761  onlyReadsHandle(getNumThreadsMutable(), effects);
3762  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3763  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3764  producesHandle(getOperation()->getOpResults(), effects);
3765  modifiesPayload(effects);
3766 }
3767 
3768 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3769  Builder b(getContext());
3770  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3771 }
3772 
3773 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3774  Builder b(getContext());
3775  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3776 }
3777 
3778 LogicalResult TileUsingForallOp::verify() {
3779  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3780  static_cast<int>(getPackedNumThreads() != Value());
3781  if (numThreadsSpec > 1)
3782  return emitOpError(
3783  "num_threads and packed_num_threads are mutually exclusive");
3784  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3785  static_cast<int>(getPackedTileSizes() != Value());
3786  if (tileSizesSpec > 1)
3787  return emitOpError(
3788  "tile_sizes and packed_tile_sizes are mutually exclusive");
3789  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3790  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3791  "must be specified");
3792  return success();
3793 }
3794 
3795 //===----------------------------------------------------------------------===//
3796 // VectorizeChildrenAndApplyPatternsOp
3797 //===----------------------------------------------------------------------===//
3798 
3799 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3800  OpBuilder &builder, OperationState &result, Value target,
3801  bool foldTypeExtensionsIntoContract, bool vectorizePadding,
3802  bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3803  result.addOperands(target);
3804  if (foldTypeExtensionsIntoContract) {
3805  result.addAttribute(
3806  VectorizeChildrenAndApplyPatternsOp::
3807  getFoldTypeExtensionsIntoContractAttrName(result.name),
3808  builder.getUnitAttr());
3809  }
3810  if (vectorizePadding) {
3811  result.addAttribute(
3812  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3813  result.name),
3814  builder.getUnitAttr());
3815  }
3816  if (vectorizeExtract) {
3817  result.addAttribute(
3818  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3819  result.name),
3820  builder.getUnitAttr());
3821  }
3822  if (flatten1DDepthwiseConv) {
3823  result.addAttribute(
3824  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3825  result.name),
3826  builder.getUnitAttr());
3827  }
3828  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3829 }
3830 
3831 namespace {
3832 /// This is a helper only to call vectorize via a pattern inside of
3833 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3834 struct VectorizationPattern : public RewritePattern {
3835  explicit VectorizationPattern(MLIRContext *context,
3836  bool vectorizeExtract = false,
3837  bool flattenConv = false)
3838  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3839  vectorizeNDExtract(vectorizeExtract),
3840  flatten1DDepthwiseConv(flattenConv) {}
3841  LogicalResult matchAndRewrite(Operation *op,
3842  PatternRewriter &rewriter) const override {
3843     if (!linalg::hasVectorizationImpl(op))
3844       return rewriter.notifyMatchFailure(op,
3845  "Unsupported Op, cannot vectorize");
3846  FailureOr<VectorizationResult> vectorResults =
3847  vectorize(rewriter, op, /*inputVectorSizes=*/{},
3848  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3849  flatten1DDepthwiseConv);
3850  if (failed(vectorResults))
3851  return failure();
3852  rewriter.replaceOp(op, vectorResults->replacements);
3853  return success();
3854  }
3855 
3856 private:
3857  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3858  /// rank >= 2.
3859  bool vectorizeNDExtract = false;
3860  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3861 /// depthwise convolutions. This should lead to better vectorization for
3862  /// tensors with a low number of channel dimensions.
3863  bool flatten1DDepthwiseConv = false;
3864 };
3865 } // namespace
3866 
3867 DiagnosedSilenceableFailure
3868 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3869  transform::TransformRewriter &rewriter, Operation *target,
3870     transform::ApplyToEachResultList &results,
3871     transform::TransformState &state) {
3872  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3873  auto diag = this->emitOpError("requires isolated-from-above targets");
3874  diag.attachNote(target->getLoc()) << "non-isolated target";
3875     return DiagnosedSilenceableFailure::definiteFailure();
3876   }
3877 
3878  MLIRContext *ctx = getContext();
3879   RewritePatternSet patterns(ctx);
3880   patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3881  getFlatten_1dDepthwiseConv());
3882 
3883  if (!getDisableTransferPermutationMapLoweringPatterns())
3884     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3885 
3886  if (!getDisableMultiReductionToContractPatterns())
3887     vector::populateVectorReductionToContractPatterns(patterns);
3888 
3890 
3891   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3892                linalg::LinalgCopyVTWForwardingPattern>(ctx,
3893                                                        /*benefit=*/2);
3894  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3895  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3897 
3899 
3900  if (getFoldTypeExtensionsIntoContract())
3901     vector::populateFoldArithExtensionPatterns(patterns);
3902 
3903  if (getVectorizePadding()) {
3904     linalg::populatePadOpVectorizationPatterns(patterns);
3905     // This creates an alternative path for lowering tensor.pad - by
3906  // decomposing it into e.g. linalg.fill.
3907     linalg::populateDecomposePadPatterns(patterns);
3908   }
3910 
3911  TrackingListener listener(state, *this);
3912  if (failed(
3913  applyPatternsGreedily(target, std::move(patterns),
3914  GreedyRewriteConfig().setListener(&listener))))
3915  return emitDefaultDefiniteFailure(target);
3916 
3917  results.push_back(target);
3918   return DiagnosedSilenceableFailure::success();
3919 }
3920 
3921 //===----------------------------------------------------------------------===//
3922 // VectorizeOp
3923 //===----------------------------------------------------------------------===//
3924 
3925 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3926  transform::TransformRewriter &rewriter,
3927  mlir::transform::TransformResults &transformResults,
3928     mlir::transform::TransformState &state) {
3929   auto targets = state.getPayloadOps(getTarget());
3930  if (std::empty(targets))
3931     return DiagnosedSilenceableFailure::success();
3932   auto transformOp = cast<TransformOpInterface>(getOperation());
3933  SmallVector<int64_t> vectorSizes;
3934   DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3935       state, transformOp, getMixedVectorSizes(), vectorSizes);
3936  if (!status.succeeded())
3937  return status;
3938 
3939  // TODO: Check that the correct number of vectorSizes was provided.
3940  for (Operation *target : targets) {
3941  if (!linalg::hasVectorizationImpl(target)) {
3942  return mlir::emitSilenceableFailure(target->getLoc())
3943  << "Unsupported Op, cannot vectorize";
3944  }
3945  FailureOr<VectorizationResult> vectorResults =
3946  linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
3947  getVectorizeNdExtract().value_or(false),
3948  /*flatten1DDepthwiseConv=*/false,
3949  getAssumeDynamicDimsMatchVecSizes().value_or(false),
3950  getCreateNamedContraction().value_or(false));
3951  if (failed(vectorResults)) {
3952  return mlir::emitSilenceableFailure(target->getLoc())
3953  << "Attempted to vectorize, but failed";
3954  }
3955  rewriter.replaceOp(target, vectorResults->replacements);
3956  }
3957 
3958   return DiagnosedSilenceableFailure::success();
3959 }
3960 
3961 void transform::VectorizeOp::getEffects(
3962     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3963   consumesHandle(getTargetMutable(), effects);
3964  onlyReadsHandle(getVectorSizesMutable(), effects);
3965  modifiesPayload(effects);
3966 }
3967 
3968 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3969  OpBuilder b(getContext());
3970  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3971 }
3972 
3973 LogicalResult transform::VectorizeOp::verify() {
3974  if (getStaticVectorSizes().size() != getScalableSizes().size())
3975  return emitOpError("expected same number of vector sizes (")
3976  << getStaticVectorSizes().size() << ") and scalable sizes ("
3977  << getScalableSizes().size() << ")";
3978  return success();
3979 }
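// Illustrative example for the verifier above: static_vector_sizes = [8, 16]
// requires scalable_sizes of length 2, e.g. [false, true], where a true entry
// marks that vector dimension as scalable (i.e. a runtime multiple of vscale).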
3980 
3981 //===----------------------------------------------------------------------===//
3982 // HoistRedundantVectorTransfersOp
3983 //===----------------------------------------------------------------------===//
3984 
3985 DiagnosedSilenceableFailure
3986 transform::HoistRedundantVectorTransfersOp::applyToOne(
3987  transform::TransformRewriter &rewriter, func::FuncOp target,
3988     transform::ApplyToEachResultList &results,
3989     transform::TransformState &state) {
3990  // WARNING: This hoisting does not model parallelism and is generally
3991  // incorrect when used on distributed loops with memref semantics!
3992  // TODO: obsolete and should be retired.
3993  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3994  results.push_back(target);
3995   return DiagnosedSilenceableFailure::success();
3996 }
3997 
3998 //===----------------------------------------------------------------------===//
3999 // HoistRedundantVectorBroadcastsOp
4000 //===----------------------------------------------------------------------===//
4001 
4002 DiagnosedSilenceableFailure
4003 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4004  transform::TransformRewriter &rewriter, mlir::Operation *target,
4005     transform::ApplyToEachResultList &results,
4006     transform::TransformState &state) {
4007  rewriter.setInsertionPoint(target);
4008  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
4009  results.push_back(target);
4010   return DiagnosedSilenceableFailure::success();
4011 }
4012 
4013 //===----------------------------------------------------------------------===//
4014 // ConvertConv2DToImg2ColOp.
4015 //===----------------------------------------------------------------------===//
4016 
4017 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4018  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4019     transform::ApplyToEachResultList &results,
4020     transform::TransformState &state) {
4021  rewriter.setInsertionPoint(target);
4022  auto maybeTransformed =
4023       TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
4024           target)
4025  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4026  return rewriteInIm2Col(rewriter, op);
4027  })
4028  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4029  return rewriteInIm2Col(rewriter, op);
4030  })
4031  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4032  return rewriteInIm2Col(rewriter, op);
4033  })
4034  .Case([&](linalg::Conv2DNchwFchwOp op) {
4035  return rewriteInIm2Col(rewriter, op);
4036  })
4037  .Default([&](Operation *op) {
4038  return rewriter.notifyMatchFailure(op, "not supported");
4039  });
4040  if (failed(maybeTransformed))
4041  return emitDefaultSilenceableFailure(target);
4042  // Handle to the operation producing the img2col tensor.
4043  results.push_back(maybeTransformed->first);
4044  // Handle to the operation that replaces the original convolution.
4045  results.push_back(maybeTransformed->second);
4046   return DiagnosedSilenceableFailure::success();
4047 }
4048 
4049 //===----------------------------------------------------------------------===//
4050 // FlattenElementwiseLinalgOp.
4051 //===----------------------------------------------------------------------===//
4052 
4053 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4054  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4055     transform::ApplyToEachResultList &results,
4056     transform::TransformState &state) {
4057  rewriter.setInsertionPoint(target);
4058  if (!isElementwise(target))
4059  return mlir::emitSilenceableFailure(target->getLoc())
4060  << "only elementwise flattening is supported";
4061 
4062  // If rank <= 1, do nothing
4063  if (target.getNumLoops() <= 1) {
4064  results.push_back(target);
4065     return DiagnosedSilenceableFailure::success();
4066   }
4067 
4068  // Attempt to flatten all dims to one.
4069  ReassociationIndices reassociation(target.getNumLoops());
4070  std::iota(reassociation.begin(), reassociation.end(), 0);
4071  auto maybeFlattened =
4072  collapseOpIterationDims(target, reassociation, rewriter);
4073  if (failed(maybeFlattened))
4074  return mlir::emitSilenceableFailure(target->getLoc())
4075  << "attempted to flatten, but failed";
4076  results.push_back(maybeFlattened->collapsedOp);
4077  rewriter.replaceOp(target, maybeFlattened->results);
4078   return DiagnosedSilenceableFailure::success();
4079 }
4080 
4081 //===----------------------------------------------------------------------===//
4082 // TransposeConv2DOp
4083 //===----------------------------------------------------------------------===//
4084 
4085 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4086  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4087     transform::ApplyToEachResultList &results,
4088     transform::TransformState &state) {
4089  rewriter.setInsertionPoint(target);
4090  auto maybeTransformed =
4091       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4092           .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4093  return transposeConv2D(rewriter, op);
4094  })
4095  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4096  return transposeConv2D(rewriter, op);
4097  })
4098  .Default([&](Operation *op) {
4099  return rewriter.notifyMatchFailure(op, "not supported");
4100  });
4101  if (failed(maybeTransformed))
4102  return emitDefaultSilenceableFailure(target);
4103  // Handle to the new Conv2D operation with transposed filters
4104  results.push_back(*maybeTransformed);
4105   return DiagnosedSilenceableFailure::success();
4106 }
4107 
4108 //===----------------------------------------------------------------------===//
4109 // TransposeMatmulOp
4110 //===----------------------------------------------------------------------===//
4111 
4112 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4113  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4114     transform::ApplyToEachResultList &results,
4115     transform::TransformState &state) {
4116  rewriter.setInsertionPoint(target);
4117  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4118  auto maybeTransformed =
4119       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4120           .Case([&](linalg::MatmulOp op) {
4121  return transposeMatmul(rewriter, op, transposeLHS);
4122  })
4123  .Case([&](linalg::BatchMatmulOp op) {
4124  return transposeBatchMatmul(rewriter, op, transposeLHS);
4125  })
4126  .Default([&](Operation *op) { return failure(); });
4127  if (failed(maybeTransformed))
4128  return emitSilenceableFailure(target->getLoc()) << "not supported";
4129  // Handle to the new Matmul operation with transposed filters
4130  results.push_back(*maybeTransformed);
4131   return DiagnosedSilenceableFailure::success();
4132 }
4133 
4134 //===----------------------------------------------------------------------===//
4135 // InsertSliceToCopyOp
4136 //===----------------------------------------------------------------------===//
4137 template <typename OpTy>
4138 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
4139                                         transform::ApplyToEachResultList &results,
4140                                         transform::TransformState &state) {
4141  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4142  tensor::ParallelInsertSliceOp>() &&
4143  "wrong op type");
4144 
4145  if (auto copySource =
4146  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4147  results.push_back(copySource);
4148     return DiagnosedSilenceableFailure::success();
4149   }
4150 
4151  // If we are inside a `ParallelCombiningOp` region, temporarily set the
4152  // insertion point outside: only ops implementing ParallelCombiningOpInterface
4153  // are allowed in there.
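  // For example (editorial illustration), a tensor.parallel_insert_slice lives
  // inside the scf.forall terminator region, so the tensor.extract_slice /
  // linalg.copy pair created below has to be emitted above that region rather
  // than next to the target op.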
4154  if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4155  rewriter.setInsertionPoint(target->getParentOp());
4156 
4157  Value extracted = tensor::ExtractSliceOp::create(
4158  rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4159  target.getMixedSizes(), target.getMixedStrides());
4160  Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4161  target.getSource(), extracted)
4162  .getResult(0);
4163  // Reset the insertion point.
4164  rewriter.setInsertionPoint(target);
4165  rewriter.replaceOpWithNewOp<OpTy>(
4166  target, copied, target.getDest(), target.getMixedOffsets(),
4167  target.getMixedSizes(), target.getMixedStrides());
4168 
4169  results.push_back(copied.getDefiningOp());
4170   return DiagnosedSilenceableFailure::success();
4171 }
4172 
4173 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4174  transform::TransformRewriter &rewriter, Operation *targetOp,
4175     transform::ApplyToEachResultList &results,
4176     transform::TransformState &state) {
4177 
4178  rewriter.setInsertionPoint(targetOp);
4179  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4180  return doit(rewriter, target, results, state);
4181  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4182  return doit(rewriter, target, results, state);
4183 
4184   DiagnosedSilenceableFailure diag =
4185       emitSilenceableError()
4186  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4187  diag.attachNote(targetOp->getLoc()) << "target op";
4188  return diag;
4189 }
4190 
4191 //===----------------------------------------------------------------------===//
4192 // MapCopyToThreadsOp
4193 //===----------------------------------------------------------------------===//
4194 
4195 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4196  transform::TransformRewriter &rewriter, Operation *target,
4197     transform::ApplyToEachResultList &results,
4198     transform::TransformState &state) {
4199  // Check if the op is supported.
4200  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4201     DiagnosedSilenceableFailure diag =
4202         emitSilenceableError()
4203  << "only linalg.copy and tensor.pad target ops are supported";
4204  diag.attachNote(target->getLoc()) << "target op";
4205  return diag;
4206  }
4207  assert(target->getNumResults() == 1 && "expected single result");
4208  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4209  if (!resultShapedType.hasStaticShape()) {
4210     DiagnosedSilenceableFailure diag =
4211         emitSilenceableError()
4212  << "only statically sized ops of rank <= 3 are supported";
4213  diag.attachNote(target->getLoc()) << "target op";
4214  return diag;
4215  }
4216 
4217  // Conservatively set the minimum viable desired bitwidth alignment.
4218  int64_t desiredBitAlignment = getDesiredBitAlignment();
4219  int64_t eltBitwidth =
4220  resultShapedType.getElementType().getIntOrFloatBitWidth();
4221  if (desiredBitAlignment % eltBitwidth != 0) {
4222  desiredBitAlignment = eltBitwidth;
4223  }
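  // Illustrative example: a desired 128-bit alignment over f32 (32-bit)
  // elements is kept since 128 % 32 == 0; for an element width that does not
  // divide the requested alignment, the code above falls back to the element
  // bitwidth itself.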
4224 
4225  gpu::CopyMappingInfo mapping(
4226  /*ctx=*/getContext(),
4227  /*totalNumThreads=*/getTotalNumThreads(),
4228  /*alignment=*/desiredBitAlignment,
4229  /*sizes=*/resultShapedType.getShape(),
4230  /*favorPredication=*/false,
4231  /*elementalBitwidth=*/
4232  resultShapedType.getElementType().getIntOrFloatBitWidth());
4233  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4234     DiagnosedSilenceableFailure diag =
4235         emitSilenceableError()
4236  << "too few threads to map copy op to threads on the most minor "
4237  "dimension, given alignment and vector size constraints, try "
4238  "smaller tile size of mapping to more threads";
4239  diag.attachNote(target->getLoc()) << "target op";
4240  return diag;
4241  }
4242 
4243  // OpBuilder only used to compute attributes.
4244  OpBuilder b(getContext());
4245  scf::SCFTilingResult tilingResult;
4246   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4247       /*rewriter=*/rewriter,
4248  /*state=*/state,
4249  /*transformOp=*/*this,
4250  /*target=*/target,
4251  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4252  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4253  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4254  /*tilingResult=*/tilingResult);
4255  if (!diag.succeeded())
4256  return diag;
4257 
4258  results.push_back(tilingResult.loops.front());
4259  for (auto op : tilingResult.tiledOps)
4260  results.push_back(op);
4261   return DiagnosedSilenceableFailure::success();
4262 }
4263 
4264 //===----------------------------------------------------------------------===//
4265 // WinogradConv2DOp
4266 //===----------------------------------------------------------------------===//
4267 
4268 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4269  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4270     transform::ApplyToEachResultList &results,
4271     transform::TransformState &state) {
4272  rewriter.setInsertionPoint(target);
4273  FailureOr<Operation *> maybeTransformed = failure();
4274  bool supported = TypeSwitch<Operation *, bool>(target)
4275  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4276  maybeTransformed =
4277  winogradConv2D(rewriter, op, getFmr());
4278  return true;
4279  })
4280  .Default([&](Operation *op) { return false; });
4281 
4282  if (!supported) {
4283  return emitSilenceableError()
4284  << "this operation is not supported to convert to Winograd Conv2D";
4285  }
4286 
4287  if (failed(maybeTransformed)) {
4288  return emitSilenceableError() << "apply Winograd Conv2D failed";
4289  }
4290 
4291  results.push_back(*maybeTransformed);
4292   return DiagnosedSilenceableFailure::success();
4293 }
4294 
4295 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4296  transform::TransformRewriter &rewriter, Operation *target,
4297     transform::ApplyToEachResultList &results,
4298     transform::TransformState &state) {
4299  rewriter.setInsertionPoint(target);
4300  FailureOr<Operation *> maybeTransformed = failure();
4301  bool supported =
4302       TypeSwitch<Operation *, bool>(target)
4303           .Case([&](linalg::WinogradFilterTransformOp op) {
4304  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4305  return true;
4306  })
4307  .Case([&](linalg::WinogradInputTransformOp op) {
4308  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4309  return true;
4310  })
4311  .Case([&](linalg::WinogradOutputTransformOp op) {
4312  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4313  return true;
4314  })
4315  .Default([&](Operation *op) { return false; });
4316 
4317  if (!supported) {
4318     DiagnosedSilenceableFailure diag =
4319         emitSilenceableError()
4320  << "this operation is not supported to decompose into other operations";
4321  diag.attachNote(target->getLoc()) << "target op";
4322  return diag;
4323  }
4324 
4325  if (failed(maybeTransformed)) {
4326     DiagnosedSilenceableFailure diag =
4327         emitSilenceableError() << "decompose Winograd operations failed";
4328  diag.attachNote(target->getLoc()) << "target op";
4329  return diag;
4330  }
4331 
4332  results.push_back(*maybeTransformed);
4333   return DiagnosedSilenceableFailure::success();
4334 }
4335 
4336 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4337 
4338 #define GET_OP_CLASSES
4339 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1276
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1374
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
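A minimal sketch of the folded-apply helper above: computing size ceildiv step over two OpFoldResults, materializing an affine.apply only when the result does not fold to a constant. The helper name and arguments are illustrative assumptions.
static OpFoldResult ceilDivOfr(OpBuilder &b, Location loc, OpFoldResult size,
                               OpFoldResult step) {
  AffineExpr d0, s0;
  bindDims(b.getContext(), d0);
  bindSymbols(b.getContext(), s0);
  // Map (d0)[s0] -> d0 ceildiv s0.
  AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1,
                                 d0.ceilDiv(s0));
  return affine::makeComposedFoldedAffineApply(b, loc, map, {size, step});
}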
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition: Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:471
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:657
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
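A minimal sketch of driving the vectorizer declared above on a single payload op. The wrapper name and the vector sizes are illustrative assumptions; the sizes must match the payload op's iteration space.
static LogicalResult vectorizeWithFixedSizes(RewriterBase &rewriter,
                                             Operation *target) {
  // Bail out early for ops the Linalg vectorizer does not handle.
  if (!linalg::hasVectorizationImpl(target))
    return failure();
  FailureOr<linalg::VectorizationResult> vectorized =
      linalg::vectorize(rewriter, target, /*inputVectorSizes=*/{8, 32},
                        /*inputScalableVecDims=*/{false, false});
  return success(succeeded(vectorized));
}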
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:245
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:346
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:857
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:512
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:496
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:487
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace linalg.batch_matmul with a transposed variant (see transposeLHS).
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:400
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
Definition: Promotion.cpp:504
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:220
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:45
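A minimal sketch of the interchange helper above: swapping the two iteration dimensions of a linalg.generic. The wrapper name and the permutation are illustrative assumptions; the permutation must cover the op's full loop count.
static LogicalResult swapOuterLoops(RewriterBase &rewriter,
                                    linalg::GenericOp genericOp) {
  SmallVector<unsigned> permutation = {1, 0};
  FailureOr<linalg::GenericOp> interchanged =
      linalg::interchangeGenericOp(rewriter, genericOp, permutation);
  return success(succeeded(interchanged));
}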
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:236
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:106
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:89
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:748
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:464
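A minimal sketch of the packing entry point above with static packed sizes built from index attributes. The wrapper name and the sizes are illustrative assumptions; one packed size is expected per iteration dimension being packed.
static LogicalResult packWithFixedSizes(RewriterBase &rewriter,
                                        linalg::LinalgOp linalgOp) {
  SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(16),
                                           rewriter.getIndexAttr(32),
                                           rewriter.getIndexAttr(64)};
  FailureOr<linalg::PackResult> packed =
      linalg::pack(rewriter, linalgOp, packedSizes);
  return success(succeeded(packed));
}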
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:422
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:491
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
Definition: Promotion.cpp:480
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:198
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:262
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:217
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
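A minimal sketch of the tiling entry point above, using scf::SCFTilingOptions (its setTileSizes setter is listed further below) with fixed tile sizes. The wrapper name and the sizes are illustrative assumptions.
static LogicalResult tileWithFixedSizes(RewriterBase &rewriter,
                                        TilingInterface op) {
  scf::SCFTilingOptions options;
  options.setTileSizes({rewriter.getIndexAttr(4), rewriter.getIndexAttr(16)});
  FailureOr<scf::SCFTilingResult> tiled =
      scf::tileUsingSCF(rewriter, op, options);
  if (failed(tiled))
    return failure();
  // tiled->tiledOps and tiled->loops (see the SCFTilingResult entries near
  // the end of this index) describe the generated loop nest.
  return success();
}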
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:651
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract_slice ops above their producers.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:61
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
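A minimal sketch of the greedy driver above, fed with two of the populate* entry points listed in this index. The wrapper name is an illustrative assumption; the op must be isolated from above and own at least one region.
static LogicalResult applyLinalgCleanups(Operation *isolatedOp) {
  RewritePatternSet patterns(isolatedOp->getContext());
  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
  linalg::populatePadOpVectorizationPatterns(patterns);
  return applyPatternsGreedily(isolatedOp->getRegion(0), std::move(patterns));
}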
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
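A minimal sketch of round-tripping the static/dynamic size encoding through OpFoldResults with the two helpers above. The wrapper name is an illustrative assumption; dynamic entries in staticSizes are assumed to use the usual ShapedType::kDynamic sentinel.
static void roundTripSizes(ArrayRef<int64_t> staticSizes,
                           ValueRange dynamicSizes, MLIRContext *ctx) {
  // Merge the static values and the dynamic SSA values into one list.
  SmallVector<OpFoldResult> mixed =
      getMixedValues(staticSizes, dynamicSizes, ctx);
  // Split the mixed list back into its static and dynamic parts.
  SmallVector<Value> newDynamicSizes;
  SmallVector<int64_t> newStaticSizes;
  dispatchIndexOpFoldResults(mixed, newDynamicSizes, newStaticSizes);
}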
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:431
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:432
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:522
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1626
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1558
Match and rewrite for the pattern:
Definition: Transforms.h:1748
Match and rewrite for the pattern:
Definition: Transforms.h:1776
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:421
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:434
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:447
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:467
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:410
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:441
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:457
LinalgPromotionOptions & setUseOriginalSubviewSize(bool originalSize)
Definition: Transforms.h:428
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:399
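A minimal sketch of chaining the LinalgPromotionOptions setters above and running subview promotion with the precondition check and promoteSubViews helpers listed earlier in this index. The wrapper name and the concrete option values are illustrative assumptions.
static LogicalResult promoteFirstTwoOperands(OpBuilder &b,
                                             linalg::LinalgOp linalgOp) {
  linalg::LinalgPromotionOptions options;
  options.setOperandsToPromote({0, 1})
      .setAlignment(16)
      .setUseFullTileBuffersByDefault(true);
  if (failed(linalg::promoteSubviewsPrecondition(linalgOp, options)))
    return failure();
  return success(succeeded(linalg::promoteSubViews(b, linalgOp, options)));
}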
Split Reduction options.
Definition: Transforms.h:476
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.