LinalgTransformOps.cpp
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Support/LLVM.h"
42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/ScopeExit.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 #include "llvm/Support/DebugLog.h"
46 #include "llvm/Support/LogicalResult.h"
47 #include <type_traits>
48 
49 using namespace mlir;
50 using namespace mlir::linalg;
51 using namespace mlir::transform;
52 
53 #define DEBUG_TYPE "linalg-transforms"
54 
55 /// Attempts to apply the pattern specified as template argument to the given
56 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
57 /// function that returns the "main" result or failure. Returns failure if the
58 /// pattern failed to apply. Extra arguments are forwarded to the pattern
59 /// constructor.
60 template <typename PatternTy, typename... Args>
61 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
62  // Check if the given operation has the type expected by the pattern.
63  using OpTy = typename llvm::function_traits<
64  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
65  auto op = dyn_cast<OpTy>(operation);
66  if (!op)
67  return failure();
68 
69  // Apply the pattern directly to the op.
70  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
71  // We want to discourage direct use of PatternRewriter in APIs, but in this
72  // very specific case, an IRRewriter is not enough.
73  PatternRewriter rewriter(operation->getContext());
74  rewriter.setInsertionPoint(operation);
75  auto result = pattern.returningMatchAndRewrite(op, rewriter);
76  if (failed(result))
77  return failure();
78  return cast<LinalgOp>(result->getOperation());
79 }
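// Usage sketch (illustrative, not part of the original listing): DecomposeOp
// below invokes this helper through its DOWNSCALE macros, e.g.
//   FailureOr<LinalgOp> res =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(target);
// The expected payload op type (here Conv2DNhwcHwcfOp) is deduced from the
// first argument of the pattern's returningMatchAndRewrite method.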
80 
81 /// Assuming that `ofr` is an index attr or a param of index type
82 /// or a transform dialect handle mapped to exactly one op
83 /// with one index result, return that value.
84 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
85  transform::TransformState &state, TransformOpInterface transformOp,
86  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
87  for (OpFoldResult ofr : ofrs) {
88  if (auto attr = dyn_cast<Attribute>(ofr)) {
89  if (!isa<IntegerAttr>(attr))
90  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
91  result.push_back(ofr);
92  continue;
93  }
94 
95  Value transformValue = cast<Value>(ofr);
96  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
97  ArrayRef<Attribute> params = state.getParams(transformValue);
98  if (params.size() != 1)
99  return transformOp.emitDefiniteFailure()
100  << "requires exactly one parameter associated";
101  result.push_back(params[0]);
102  continue;
103  }
104 
105  auto payloadOps = state.getPayloadOps(transformValue);
106  if (!llvm::hasSingleElement(payloadOps)) {
107  DiagnosedSilenceableFailure diag =
108  transformOp.emitSilenceableError()
109  << "handle must be mapped to exactly one payload op";
110  diag.attachNote(transformValue.getLoc())
111  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
112  return diag;
113  }
114 
115  Operation *op = *payloadOps.begin();
116  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
117  DiagnosedSilenceableFailure diag =
118  transformOp.emitSilenceableError()
119  << "payload op must have exactly 1 index result";
120  diag.attachNote(op->getLoc())
121  << "has " << op->getNumResults() << " results";
122  return diag;
123  }
124  result.push_back(op->getResult(0));
125  }
126 
127  return DiagnosedSilenceableFailure::success();
128 }
129 
130 // Given a list of params that are index attrs or a list of OpFoldResults
131 // that are either index attrs or op handles, return a list of OpFoldResults
132 // of index attrs or a list of OpFoldResults where all op handles are
133 // replaced with the first (and only) OpResult of that payload op.
134 // (There must be exactly one parameter associated with the AnyParamType or
135 // one mapped payload op which must have exactly one index result.)
136 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
137  transform::TransformState &state, TransformOpInterface transformOp,
138  SmallVector<OpFoldResult> &result, Value packedHandle) {
139  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
140  ArrayRef<Attribute> params = state.getParams(packedHandle);
141  for (auto param : params) {
142  if (!isa<IntegerAttr>(param))
143  return transformOp.emitDefiniteFailure()
144  << "expected the parameter to be associated with an integer "
145  "attribute";
146  result.push_back(param);
147  }
148  return DiagnosedSilenceableFailure::success();
149  }
150 
151  for (Operation *op : state.getPayloadOps(packedHandle)) {
152  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
153  DiagnosedSilenceableFailure diag =
154  transformOp.emitSilenceableError()
155  << "payload op must have exactly 1 index result";
156  diag.attachNote(op->getLoc())
157  << "has " << op->getNumResults() << " results";
158  return diag;
159  }
160  result.push_back(op->getResult(0));
161  }
162 
163  return DiagnosedSilenceableFailure::success();
164 }
165 
166 /// When possible, converts each `OpFoldResult` in `mixedResults` to
167 /// an integer if the value can be statically inferred. If a result
168 /// is a `Value` then it must be either a `ParamType` or a handle
169 /// to a constant-like op.
170 static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
171  TransformState &state, TransformOpInterface &transformOp,
172  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
173  for (OpFoldResult paramOrHandle : mixedResults) {
174  if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
175  reified.push_back(cast<IntegerAttr>(attr).getInt());
176  continue;
177  } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
178  ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
179  if (params.size() != 1)
180  return transformOp.emitSilenceableError() << "expected a single param";
181  reified.push_back(
182  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
183  continue;
184  }
185 
186  Value handle = cast<Value>(paramOrHandle);
187  if (!isa<TransformHandleTypeInterface>(handle.getType()))
188  return transformOp.emitSilenceableError() << "unexpected value handle";
189  auto payload = state.getPayloadOps(handle);
190  if (!llvm::hasSingleElement(payload))
191  return transformOp.emitSilenceableError()
192  << "requires param or handle that is mapped to 1 payload op";
193 
194  Operation *paramOrHandlePayloadOp = *payload.begin();
195  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
196  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
197  return transformOp.emitSilenceableError()
198  << "requires param or handle to be result of op with 1 index "
199  "result";
200  }
201 
202  IntegerAttr attr;
203  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
204  return transformOp.emitSilenceableError()
205  << "requires param or handle to be the result of a constant like "
206  "op";
207 
208  reified.push_back(attr.getInt());
209  }
210  return DiagnosedSilenceableFailure::success();
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // Apply...PatternsOp
215 //===----------------------------------------------------------------------===//
216 
217 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
218  RewritePatternSet &patterns) {
219  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
220 }
221 
222 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
223  RewritePatternSet &patterns) {
224  linalg::populateDecomposePackUnpackPatterns(patterns);
225 }
226 
227 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
228  RewritePatternSet &patterns) {
229  linalg::populateDecomposePadPatterns(patterns);
230 }
231 
232 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
233  RewritePatternSet &patterns) {
234  linalg::ControlDropUnitDims options;
235  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
236 }
237 
238 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
239  RewritePatternSet &patterns) {
240  linalg::ControlDropUnitDims options;
241  options.rankReductionStrategy =
242  linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
243  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
244 }
245 
246 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
247  RewritePatternSet &patterns) {
248  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
249 }
250 
251 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
252  RewritePatternSet &patterns) {
253  linalg::populateFoldAddIntoDestPatterns(patterns);
254 }
255 
256 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
257  RewritePatternSet &patterns) {
258  linalg::populatePadOpVectorizationPatterns(patterns);
259 }
260 
261 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
262  RewritePatternSet &patterns) {
263  linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
264 }
265 
266 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
267  RewritePatternSet &patterns) {
268  linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // BufferizeToAllocationOp
273 //===----------------------------------------------------------------------===//
274 
275 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
276  OperationState &result,
277  Value target,
278  Attribute memorySpace) {
279  SmallVector<Type> resultTypes;
280  resultTypes.push_back(b.getType<transform::AnyValueType>());
281  resultTypes.push_back(b.getType<transform::AnyOpType>());
282  return build(b, result,
283  /*resultTypes=*/resultTypes,
284  /*target=*/target,
285  /*memorySpace=*/memorySpace);
286 }
287 
288 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
289  OperationState &result,
290  Value target,
291  int64_t memorySpace) {
292  SmallVector<Type> resultTypes;
293  resultTypes.push_back(b.getType<transform::AnyValueType>());
294  resultTypes.push_back(b.getType<transform::AnyOpType>());
295  return build(b, result,
296  /*resultTypes=*/resultTypes,
297  /*target=*/target,
298  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
299 }
300 
301 namespace {
302 class NewOpsListener : public RewriterBase::ForwardingListener {
303 public:
304  using RewriterBase::ForwardingListener::ForwardingListener;
305 
306  SmallVector<Operation *> getNewOps() const {
307  return SmallVector<Operation *>(newOps.begin(), newOps.end());
308  }
309 
310 private:
311  void notifyOperationInserted(Operation *op,
312  OpBuilder::InsertPoint previous) override {
313  ForwardingListener::notifyOperationInserted(op, previous);
314  // We only care about newly created ops.
315  if (previous.isSet())
316  return;
317  auto inserted = newOps.insert(op);
318  (void)inserted;
319  assert(inserted.second && "expected newly created op");
320  }
321 
322  void notifyOperationErased(Operation *op) override {
323  ForwardingListener::notifyOperationErased(op);
324  op->walk([&](Operation *op) { newOps.erase(op); });
325  }
326 
327  DenseSet<Operation *> newOps;
328 };
329 } // namespace
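// Note: BufferizeToAllocationOp::apply below installs this listener on the
// rewriter so that every op created during bufferization can be reported
// through the op's new_ops result handle.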
330 
331 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
332  transform::TransformRewriter &rewriter,
333  transform::TransformResults &results, transform::TransformState &state) {
334  // Attach listener to keep track of newly created ops.
335  OpBuilder::Listener *previousListener = rewriter.getListener();
336  auto resetListener =
337  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
338  NewOpsListener newOpsListener(previousListener);
339  rewriter.setListener(&newOpsListener);
340 
341  linalg::BufferizeToAllocationOptions options;
342  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
343  options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
344  MaterializeInDestination;
345  } else if (getMemcpyOp() == "memref.copy") {
346  options.memcpyOp =
347  linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
348  } else if (getMemcpyOp() == "linalg.copy") {
349  options.memcpyOp =
350  linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
351  } else {
352  llvm_unreachable("invalid memcpy op");
353  }
354  if (getAllocOp() == "memref.alloc") {
355  options.allocOp =
356  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
357  } else if (getAllocOp() == "memref.alloca") {
358  options.allocOp =
359  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
360  } else {
361  llvm_unreachable("invalid alloc op");
362  }
363  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
364  options.emitDealloc = getEmitDealloc();
365 
366  // Bufferize ops.
367  Attribute memorySpace =
368  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
369  SmallVector<Value> allocatedBuffers;
370  for (Operation *op : state.getPayloadOps(getTarget())) {
371  Value buffer =
372  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
373  if (!buffer) {
374  DiagnosedSilenceableFailure diag = emitSilenceableError()
375  << "failed to bufferize operation";
376  diag.attachNote(op->getLoc()) << "target payload op";
377  return diag;
378  }
379  allocatedBuffers.push_back(buffer);
380  }
381 
382  // Set results.
383  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
384  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
385  return DiagnosedSilenceableFailure::success();
386 }
387 
388 void transform::BufferizeToAllocationOp::getEffects(
389  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
390  if (getBufferizeDestinationOnly()) {
391  // The destination is replaced with a newly allocated buffer, but the op
392  // itself remains in place.
393  onlyReadsHandle(getTargetMutable(), effects);
394  } else {
395  consumesHandle(getTargetMutable(), effects);
396  }
397  producesHandle(getOperation()->getOpResults(), effects);
398  modifiesPayload(effects);
399 }
400 
401 LogicalResult transform::BufferizeToAllocationOp::verify() {
402  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
403  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
404  return emitOpError() << "unsupported memcpy op";
405  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
406  return emitOpError() << "unsupported alloc op";
407  return success();
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // DecomposeOp
412 //===----------------------------------------------------------------------===//
413 
414 DiagnosedSilenceableFailure
415 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
416  LinalgOp target,
417  transform::ApplyToEachResultList &results,
418  transform::TransformState &state) {
419 #define DOWNSCALE(trans) \
420  { \
421  FailureOr<LinalgOp> res = tryApply<trans>(target); \
422  if (succeeded(res)) { \
423  results.push_back(*res); \
424  return DiagnosedSilenceableFailure::success(); \
425  } \
426  }
427 
428 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
429 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
430 
431  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
432  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
433  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
434  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
435  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
436  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
437  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
438  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
439  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
442 #undef DOWNSCALE_NORMAL
443 #undef DOWNSCALE_CALL
444 #undef DOWNSCALE
445  return emitDefaultSilenceableFailure(target);
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // DecomposeInterfaceOp
450 //===----------------------------------------------------------------------===//
451 
452 // Decompose the target operation if it implements the AggregatedOpInterface.
453 // Push the decomposed operations (the ones that replace the values produced
454 // by \p target) into `results`.
455 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
456  transform::TransformRewriter &rewriter, Operation *target,
457  transform::ApplyToEachResultList &results,
458  transform::TransformState &state) {
459  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
460  if (!decomposableOp) {
461  failed(rewriter.notifyMatchFailure(target,
462  "payload is not a decomposable op"));
463  return emitDefaultSilenceableFailure(target);
464  }
465 
466  FailureOr<SmallVector<Value>> maybeNewResults =
467  decomposableOp.decomposeOperation(rewriter);
468  if (failed(maybeNewResults))
469  return emitDefaultSilenceableFailure(target);
470 
471  rewriter.replaceOp(decomposableOp, *maybeNewResults);
472  for (Value val : *maybeNewResults) {
473  Operation *definition = val.getDefiningOp();
474  if (definition)
475  results.push_back(definition);
476  }
477  return DiagnosedSilenceableFailure::success();
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // EliminateLinalgOpAnchoredEmptyTensorsOp
482 //===----------------------------------------------------------------------===//
483 
484 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
485  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
486  onlyReadsHandle(getTargetMutable(), effects);
487  modifiesPayload(effects);
488 }
489 
490 DiagnosedSilenceableFailure
491 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
492  transform::TransformRewriter &rewriter, TransformResults &transformResults,
493  TransformState &state) {
494  bufferization::OneShotBufferizationOptions options;
495  options.allowReturnAllocsFromLoops = true;
496 
497  for (Operation *target : state.getPayloadOps(getTarget())) {
498  bufferization::OneShotAnalysisState state(target, options);
499  if (failed(analyzeOp(target, state)))
500  return mlir::emitSilenceableFailure(target->getLoc())
501  << "failed to analyze op";
502  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
503  rewriter, target, state)))
504  return mlir::emitSilenceableFailure(target->getLoc())
505  << "failed to eliminate LinalgOp anchored tensor.empty ops";
506  }
507  return DiagnosedSilenceableFailure::success();
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // FuseOp
512 //===----------------------------------------------------------------------===//
513 
514 /// Apply a tiling transformation to all payload ops and store both the
515 /// tiled operation as well as the created tile loops.
516 template <typename Range>
517 static LogicalResult applyTilingToAll(
518  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
519  unsigned numLoops, transform::TransformResults &transformResults,
520  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
521  applyFn) {
522  SmallVector<Operation *> tiledLinalgOps;
523  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
524 
525  for (Operation *target : payloadOps) {
526  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
527  if (!tilingInterfaceOp)
528  return transformOp->emitError("only TilingInterface ops are supported");
529 
530  rewriter.setInsertionPoint(target);
531  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
532  applyFn(tilingInterfaceOp);
533  if (failed(tiledResults))
534  return failure();
535 
536  // Perform the replacement of tiled and fused values.
537  SmallVector<Operation *> opsToReplace{target};
538  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
539  for (Operation *toReplace : opsToReplace) {
540  for (OpResult res : toReplace->getResults())
541  if (auto replacement = tiledResults->replacements.lookup(res))
542  rewriter.replaceAllUsesWith(res, replacement);
543  if (toReplace->use_empty()) {
544  rewriter.eraseOp(toReplace);
545  }
546  }
547 
548  // Report back the relevant handles to the transform op.
549  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
550  assert(tiledResults->loops.size() == numLoops &&
551  "Mismatched number of loops, tile and fuse transform should have "
552  "failed");
553  for (unsigned int i = 0; i < numLoops; ++i)
554  loopOps[i].push_back(tiledResults->loops[i]);
555  }
556 
557  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
558  for (unsigned int i = 0; i < numLoops; ++i)
559  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
560 
561  return success();
562 }
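// Result layout produced by this helper: result #0 of the transform op maps to
// the tiled (and fused) payload ops, and results #1..#numLoops map to the
// generated loops, one handle per loop level.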
563 
564 DiagnosedSilenceableFailure
565 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
566  mlir::transform::TransformResults &transformResults,
567  mlir::transform::TransformState &state) {
568  SmallVector<int64_t> tileSizes =
569  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
570  SmallVector<int64_t> tileInterchange =
571  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
572 
573  scf::SCFTilingOptions tilingOptions;
574  tilingOptions.interchangeVector = tileInterchange;
575  SmallVector<OpFoldResult> tileSizesOfr =
576  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
577  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
578  scf::SCFTileAndFuseOptions tileAndFuseOptions;
579  tileAndFuseOptions.tilingOptions = tilingOptions;
580 
581  if (getApplyCleanup()) {
582  MLIRContext *context = rewriter.getContext();
583  RewritePatternSet patterns(context);
584  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
587  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
588  }
589 
590  LogicalResult result = applyTilingToAll(
591  rewriter, getOperation(), state.getPayloadOps(getTarget()),
592  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
593  [&](TilingInterface tilingInterfaceOp)
594  -> FailureOr<scf::SCFTileAndFuseResult> {
595  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
596  tileAndFuseOptions);
597  });
598  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
599  : DiagnosedSilenceableFailure::success();
600 }
601 
602 LogicalResult transform::FuseOp::verify() {
603  SmallVector<int64_t> permutation =
604  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
605  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
606  if (!std::is_permutation(sequence.begin(), sequence.end(),
607  permutation.begin(), permutation.end())) {
608  return emitOpError() << "expects interchange to be a permutation, found "
609  << getTileInterchange();
610  }
611 
612  SmallVector<int64_t> sizes =
613  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
614  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
615  if (numExpectedLoops != getNumResults() - 1)
616  return emitOpError() << "expects " << numExpectedLoops << " loop results";
617 
618  return success();
619 }
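// Worked example for the check above: tile_sizes = [4, 0, 8] produces two
// loops (zero entries are not tiled), so the op must declare three results:
// the fused op handle plus one handle per generated loop.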
620 
621 //===----------------------------------------------------------------------===//
622 // FuseIntoContainingOp
623 //===----------------------------------------------------------------------===//
624 
625 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
626  OperationState &result,
627  Value producerOp,
628  Value containingOp) {
629  result.addOperands({producerOp, containingOp});
630  auto resultType = transform::AnyOpType::get(builder.getContext());
631  result.addTypes({resultType, resultType});
632 }
633 
634 /// Add new operands to the forall op for users of the producerOp
635 /// that are dominated by the containing scf.forall op.
636 static Operation *replaceForAllWithNewSignature(
637  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
638  Operation *containingOp, TilingResult &tileAndFuseResult,
639  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
640  SmallVector<OpFoldResult> &sizes) {
641 
642  // Count number of users not including the containing op
643  SetVector<Operation *> dominatedUsers;
644  DominanceInfo domInfo(containingOp);
645  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
646  if (!containingOp->isAncestor(user) &&
647  (domInfo.dominates(containingOp, user))) {
648  dominatedUsers.insert(user);
649  }
650  }
651  if (dominatedUsers.empty())
652  return nullptr;
653 
654  // Create new scf.forall op
655  auto forallOp = cast<scf::ForallOp>(containingOp);
656  OpBuilder::InsertionGuard g(rewriter);
657  rewriter.setInsertionPoint(forallOp);
658 
659  // Get new output
660  Location loc = forallOp.getLoc();
661  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
662  if (!genericOp)
663  return nullptr;
664  SmallVector<Value> outputs = genericOp.getOutputs();
665  SmallVector<Value> newOuts(forallOp.getOutputs());
666  newOuts.push_back(outputs[resultNumber]);
667 
668  // Create new scf.forall op
669  auto newforallOp = scf::ForallOp::create(
670  rewriter, loc, forallOp.getMixedLowerBound(),
671  forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
672  forallOp.getMapping());
673  rewriter.eraseBlock(newforallOp.getBody());
674  newforallOp.getRegion().takeBody(forallOp.getRegion());
675 
676  // Add additional block argument for new value being returned
677  // and replaces all uses of the new output with corresponding bbArg
678  // inside the scf.forall to enable fusion into this new scf.forall.
679  newforallOp.getBody()->addArgument(newOuts.back().getType(),
680  newOuts.back().getLoc());
681  auto bbArgs = newforallOp.getBody()->getArguments();
682  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
683  [&](OpOperand &use) {
684  Operation *op = use.getOwner();
685  return newforallOp->isProperAncestor(op);
686  });
687 
688  // Fix terminator
689  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
690  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
691  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
692  Operation *firstYieldOp = yieldingOps.front();
693  rewriter.setInsertionPoint(firstYieldOp);
694  Value src = tileAndFuseResult.tiledValues[0];
695  Value dst = newforallOp.getRegionIterArgs().back();
696  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
697  tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
698  dst, offsets, sizes, strides);
699 
700  for (auto result : llvm::enumerate(forallOp.getResults())) {
701  rewriter.replaceAllUsesWith(result.value(),
702  newforallOp->getResult(result.index()));
703  }
704  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
705  newforallOp->getResults().back(),
706  [&](OpOperand &use) {
707  Operation *user = use.getOwner();
708  return dominatedUsers.contains(user);
709  });
710  return newforallOp;
711 }
712 
713 /// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
714 /// if the operand 'src' is equal to 'dst' or equal to an iter arg present in
715 /// an outer loop. To determine the second condition, this function iterates
716 /// using a worklist over the enclosing loops, trying to find 'src' in any of
717 /// the parent loop's iter args.
718 static bool sameOrEquivalentIterArg(Value src, Value dst) {
719  // Stack-like vector containing possible iterArg candidates. The first one
720  // is dst, and we will traverse the IR from there.
721  SmallVector<Value> destWorklist;
722  destWorklist.push_back(dst);
723 
724  while (!destWorklist.empty()) {
725  Value currentDst = destWorklist.pop_back_val();
726 
727  // We have found the same operand in some iter arg in the loop structure,
728  // so src and dst are equivalent.
729  if (src == currentDst)
730  return true;
731 
732  // The operands are not equivalent, look for enclosing loops over
733  // currentDst.
734  auto bbArg = dyn_cast<BlockArgument>(currentDst);
735  if (!bbArg)
736  continue;
737 
738  Block *parentBlock = bbArg.getOwner();
739  assert(parentBlock && "unlinked block argument");
740 
741  Operation *parentOp = parentBlock->getParentOp();
742  assert(parentOp && "expected block argument with parent operation");
743 
744  // Check if parent is loop-like. If it's not, do not add it to the worklist.
745  auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
746  if (!parentLoop)
747  continue;
748 
749  for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
750  // No need to check for null as innerIterArg is tied to parentLoop.
751  OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
752  Value loopBlockArgument =
753  parentLoop->getOperand(operand->getOperandNumber());
754  destWorklist.push_back(loopBlockArgument);
755  }
756  }
757 
758  return false;
759 }
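// Illustrative example (assumed IR, for exposition only): given
//   %r0 = scf.for ... iter_args(%a = %init) {
//     %r1 = scf.for ... iter_args(%b = %a) { ... }
//   }
// sameOrEquivalentIterArg(%init, %b) returns true, because walking %b's
// enclosing loops reaches the init value %init through the tied iter args.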
760 
761 /// Find the first "extract" user of `producerOp` and tile it right before its
762 /// use. The tiled op is fused under the `containingOp`.
763 /// Return this fused op on success or nullptr if anything fails.
764 /// If tiled op has uses that are dominated by `containingOp`, return
765 /// a new `containingOp` with results of the fused op appended to
766 /// results of the `containingOp` or nullptr if there are no dominated uses.
767 static std::tuple<SmallVector<Operation *>, Operation *>
768 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
769  Operation *producerOp, Operation *containingOp) {
770  LDBG() << "Try to fuse a direct extract use";
771  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
772  if (!tileableProducer) {
773  diag.attachNote(producerOp->getLoc())
774  << "producer is not a TileableInterface: " << *producerOp;
775  return {};
776  }
777 
778  // Search the producer slices accessed within the containing operation.
779  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
780  // evolve into an interface.
781  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
782  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
783  return sliceOp && containingOp->isProperAncestor(sliceOp);
784  });
785 
786  // Find a fusion opportunity.
787  if (it == tileableProducer->getUsers().end()) {
788  diag.attachNote(tileableProducer->getLoc())
789  << "could not find fusion opportunity for: " << *tileableProducer;
790  return {};
791  }
792  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
793 
794  // Try to fuse the producer in-place.
795  OpBuilder::InsertionGuard guard(rewriter);
796  rewriter.setInsertionPoint(sliceOpToTile);
797 
798  // Clone the producer inside the consumer and try to update the producer init
799  // operands using the loop bbArgs if applicable. More precisely, if the bbArg
800  // of the container loop points to a value that is used by the consumer op,
801  // then, instead of using that value in the consumer, use the value coming
802  // from the bbArg instead. This allows reusing the output tensor (instead of
803  // creating a new one) of the container when both producer and container write
804  // to the same output.
805  if (LoopLikeOpInterface containerLoop =
806  dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
807  Operation *clone = rewriter.clone(*producerOp);
808  rewriter.modifyOpInPlace(clone, [&]() {
809  // Iterate over the outputs of the producer and over the loop bbArgs and
810  // check if any bbArg points to the same value as the producer output. In
811  // such case, make the producer output point to the bbArg directly.
812  for (OpOperand &initOperandPtr :
813  cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
814  Value producerOperand =
815  clone->getOperand(initOperandPtr.getOperandNumber());
816  for (BlockArgument containerIterArg :
817  containerLoop.getRegionIterArgs()) {
818  OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
819  Value consumerOperand =
820  containerLoop->getOperand(bbArg->getOperandNumber());
821  // The producer has the same init as the loop bbArg, use it.
822  if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
823  initOperandPtr.set(containerIterArg);
824  }
825  }
826  }
827  });
828 
829  tileableProducer = dyn_cast<TilingInterface>(clone);
830  }
831 
832  // Tile the producer.
833  int64_t resultNumber =
834  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
835  LDBG() << "resultNumber: " << resultNumber;
836 
837  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
838  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
839 
840  FailureOr<TilingResult> tileAndFuseResult =
841  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
842  sizes);
843 
844  if (failed(tileAndFuseResult)) {
845  diag.attachNote(tileableProducer->getLoc())
846  << "failed to tile producer op: " << *tileableProducer;
847  return {};
848  }
849 
850 #ifndef NDEBUG
851  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
852  LDBG() << "tiledProducer: " << *tiledOp;
853  }
854 #endif
855 
856  // Replace the extract op.
857  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
858  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
859  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
860  if (failed(maybeRankReduced)) {
861  diag.attachNote(producerOp->getLoc())
862  << "shape types don't match (missing canonicalization?):\nTiledOp: "
863  << tileAndFuseResult->tiledValues[0]
864  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
865  return {};
866  }
867  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
868 
869  // Add new outputs to containing op, if required
870  Operation *newContainingOp = replaceForAllWithNewSignature(
871  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
872  resultNumber, offsets, sizes);
873 
874  // Cleanup clone.
875  if (dyn_cast<LoopLikeOpInterface>(containingOp))
876  rewriter.eraseOp(tileableProducer);
877 
878  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
879 }
880 
881 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
882 /// it is exactly the `containingOp`, otherwise bail.
883 /// Then, find the first "extract" user of the tied block argument and tile it
884 /// right before its "extract" use. The tiled op is fused under the
885 /// `containingOp`.
886 /// Return this fused op on success or nullptr if anything fails.
887 static SmallVector<Operation *>
888 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
889  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
890  Operation *containingOp) {
891  LDBG() << "Try to fuse an extract use through block argument";
892 
893  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
894  if (!tileableProducer) {
895  diag.attachNote(producerOp->getLoc())
896  << "producer is not a TileableInterface: " << *producerOp;
897  return {};
898  }
899 
900  // Search the first use by a "scf::ForallOp" user.
901  scf::ForallOp forallOp;
902  auto itProducerUses =
903  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
904  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
905  return forallOp;
906  });
907  // If it's not from the containing op, return.
908  if (!forallOp || forallOp != containingOp) {
909  diag.attachNote(tileableProducer->getLoc())
910  << "could not find a use by the containing op: " << *tileableProducer;
911  return {};
912  }
913 
914  // Search the producer slices accessed within the containing
915  // operation.
916  // TODO: Generalize to more extract/insert/parallel_insert triples.
917  // Maybe evolve into an interface.
918  OpOperand *pUse = &(*itProducerUses);
919  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
920 
921  // Search the producer slices accessed within the containing operation.
922  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
923  // evolve into an interface.
924  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
925  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
926  return sliceOp && containingOp->isProperAncestor(sliceOp);
927  });
928 
929  // Find a fusion opportunity.
930  if (itBBArgUsers == bbArg.getUsers().end()) {
931  diag.attachNote(containingOp->getLoc())
932  << "could not find fusion opportunity for bbArg: " << bbArg;
933  return {};
934  }
935  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
936 
937  // Try to fuse the producer in-place.
938  OpBuilder::InsertionGuard guard(rewriter);
939  rewriter.setInsertionPoint(sliceOpToTile);
940 
941  // Replace the use in the tileableProducer before tiling: clone, replace and
942  // then tile.
943  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
944  LDBG() << "resultNumber: " << resultNumber;
945 
946  // Gather destination tensors.
947  SmallVector<Value> destinationTensors;
948  if (failed(tensor::getOrCreateDestinations(
949  rewriter, tileableProducer->getLoc(), tileableProducer,
950  destinationTensors))) {
951  diag.attachNote(tileableProducer->getLoc())
952  << "failed to get destination tensors for: " << *tileableProducer;
953  return {};
954  }
955 
956  IRMapping bvm;
957  bvm.map(destinationTensors[resultNumber], bbArg);
958  auto tileableProducerClone =
959  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
960  auto scopeGuard =
961  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
962 
963  // Tile the producer.
964  FailureOr<TilingResult> tileAndFuseResult =
965  tileableProducerClone.generateResultTileValue(
966  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
967  sliceOpToTile.getMixedSizes());
968  if (failed(tileAndFuseResult)) {
969  diag.attachNote(tileableProducer->getLoc())
970  << "failed to tile producer op: " << *tileableProducer;
971  return {};
972  }
973 
974  // Replace the extract op.
975  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
976  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
977  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
978  assert(succeeded(maybeRankReduced) && "unexpected shape");
979  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
980 
981  // Replace the use in containingOp.
982  rewriter.modifyOpInPlace(containingOp, [&]() {
983  containingOp->setOperand(pUse->getOperandNumber(),
984  destinationTensors.front());
985  });
986 
987  return tileAndFuseResult->tiledOps;
988 }
989 
990 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
991  Operation *producerOp,
992  Operation *containingOp) {
993  LDBG() << "Try to fuse an use by cloning";
994 
995  // Gather all uses inside the containing op.
996  SmallVector<OpOperand *> uses;
997  for (OpResult result : producerOp->getOpResults()) {
998  for (OpOperand &use : result.getUses()) {
999  if (containingOp->isProperAncestor(use.getOwner())) {
1000  uses.push_back(&use);
1001  continue;
1002  }
1003  // Cannot clone and fuse if the use is by the containing op itself: fail
1004  // immediately.
1005  if (containingOp == use.getOwner()) {
1006  diag.attachNote(producerOp->getLoc())
1007  << "producer op use by containing op cannot be fused by cloning";
1008  return nullptr;
1009  }
1010  }
1011  }
1012 
1013  // Check for a non-empty list of fusion opportunities.
1014  if (uses.empty()) {
1015  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1016  return nullptr;
1017  }
1018 
1019  // Clone and fuse inside the containing op.
1020  Operation *fusedOp = nullptr;
1021  OpOperand *use = uses.front();
1022  // Parallel insert slice is not a valid clone destination.
1023  // TODO: Generalize to other type of ops.
1024  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1025  "Parallel insert slice is not a valid clone destination");
1026  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1027  LDBG() << "resultNumber: " << resultNumber;
1028 
1029  OpBuilder::InsertionGuard guard(rewriter);
1030  rewriter.setInsertionPoint(use->getOwner());
1031  fusedOp = rewriter.clone(*producerOp);
1032  rewriter.modifyOpInPlace(
1033  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1034 
1035  return fusedOp;
1036 }
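// FuseIntoContainingOp::apply below tries these helpers in order for every
// producer: first tileAndFuseFirstExtractUse, then the block-argument variant,
// and finally cloneAndFuseFirstUse as a fallback.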
1037 
1038 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1039  // Allow repeated handles since we are fusing everything anyway.
1040  return true;
1041 }
1042 
1043 DiagnosedSilenceableFailure
1044 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1045  transform::TransformResults &results,
1046  transform::TransformState &state) {
1047  SmallVector<Operation *> fusedOps;
1048  auto producerOps = state.getPayloadOps(getProducerOp());
1049  auto containingOps = state.getPayloadOps(getContainingOp());
1050  if (!llvm::hasSingleElement(containingOps)) {
1051  return emitDefiniteFailure()
1052  << "requires exactly one containing_op handle (got "
1053  << llvm::range_size(containingOps) << ")";
1054  }
1055  Operation *containingOp = *containingOps.begin();
1056 
1057  // If nothing to fuse, propagate success.
1058  if (std::empty(producerOps)) {
1059  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1060  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1061  return DiagnosedSilenceableFailure::success();
1062  }
1063 
1064  // Helper function to find the next producer that should be fused. Take any
1065  // producer that has a use inside the containing op.
1066  SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1067  auto getNextProducer = [&]() -> FailureOr<Operation *> {
1068  for (const auto &it : enumerate(remainingProducers)) {
1069  Operation *producerOp = it.value();
1070  // The containing op may be a user of producerOp: use isAncestor.
1071  int64_t numUsesInContainingOp =
1072  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1073  return containingOp->isAncestor(op);
1074  });
1075  // TODO: When resolving the TODO below (no duplicate ops), take an op
1076  // that has no use among the remaining producers. This is a topological
1077  // sorting.
1078  if (numUsesInContainingOp > 0) {
1079  if (numUsesInContainingOp == 1)
1080  remainingProducers.erase(remainingProducers.begin() + it.index());
1081  return producerOp;
1082  }
1083  }
1084  return failure();
1085  };
1086 
1087  while (!remainingProducers.empty()) {
1088  auto nextProducer = getNextProducer();
1089  if (failed(nextProducer)) {
1090  auto diag = mlir::emitSilenceableFailure(getLoc())
1091  << "could not find next producer to fuse into container";
1092  diag.attachNote(containingOp->getLoc()) << "containing op";
1093  return diag;
1094  }
1095 
1096  Operation *producerOp = *nextProducer;
1097 
1098  // Default diagnostic, to be complemented with more failure information.
1099  Diagnostic diag(getLoc(), DiagnosticSeverity::Remark);
1100  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1101 
1102  // TODO: If there are multiple uses of the producer in the containing op,
1103  // we currently tile/clone the op multiple times (once per use). In some
1104  // cases, we can tile/clone once and reuse the value for each use.
1105  // Furthermore, producers should then be traversed according to a
1106  // topological sorting.
1107  auto [tiledOps, newContainingOp] =
1108  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1109  if (!tiledOps.empty()) {
1110  LDBG() << "\nFused a direct extract use\n" << *containingOp;
1111  fusedOps.append(tiledOps);
1112  if (newContainingOp) {
1113  // Update handles associated with the containing op so we don't need to
1114  // invalidate them. This is a hack to support better composability
1115  // between tiling and fusion while a proper mechanism is being
1116  // investigated.
1117  //
1118  // DO NOT replicate this elsewhere unless you understand what you are
1119  // doing.
1120  LogicalResult replacementStatus =
1121  rewriter.notifyPayloadOperationReplaced(containingOp,
1122  newContainingOp);
1123  (void)replacementStatus;
1124  assert(succeeded(replacementStatus) &&
1125  "unable to update transform state mapping");
1126  rewriter.eraseOp(containingOp);
1127  containingOp = newContainingOp;
1128  }
1129  continue;
1130  }
1131 
1132  SmallVector<Operation *> tiledContainingOpOperand =
1133  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1134  rewriter, diag, producerOp, containingOp);
1135  if (!tiledContainingOpOperand.empty()) {
1136  LDBG() << "\nFused an extract use through block argument\n"
1137  << *containingOp;
1138  fusedOps.append(tiledContainingOpOperand);
1139  continue;
1140  }
1141 
1142  Operation *cloned =
1143  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1144  if (cloned) {
1145  LDBG() << "\nFused an use by cloning\n" << *containingOp;
1146  fusedOps.push_back(cloned);
1147  continue;
1148  }
1149  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1150  }
1151 
1152  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1153  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1154  return DiagnosedSilenceableFailure::success();
1155 }
1156 
1157 void transform::FuseIntoContainingOp::getEffects(
1158  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1159  consumesHandle(getProducerOpMutable(), effects);
1160  onlyReadsHandle(getContainingOpMutable(), effects);
1161  producesHandle(getOperation()->getOpResults(), effects);
1162  modifiesPayload(effects);
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // GeneralizeOp
1167 //===----------------------------------------------------------------------===//
1168 
1169 DiagnosedSilenceableFailure
1170 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1171  LinalgOp target,
1172  transform::ApplyToEachResultList &results,
1173  transform::TransformState &state) {
1174  // Exit early if no transformation is needed.
1175  if (isa<GenericOp>(target)) {
1176  results.push_back(target);
1177  return DiagnosedSilenceableFailure::success();
1178  }
1179  rewriter.setInsertionPoint(target);
1180  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1181  if (succeeded(generic)) {
1182  results.push_back(generic->getOperation());
1183  return DiagnosedSilenceableFailure::success();
1184  }
1185  return emitDefaultSilenceableFailure(target);
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // SpecializeOp
1190 //===----------------------------------------------------------------------===//
1191 
1192 DiagnosedSilenceableFailure
1193 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1194  LinalgOp target,
1195  transform::ApplyToEachResultList &results,
1196  transform::TransformState &state) {
1197  // Exit early if the operation is not a generic.
1198  if (!isa<GenericOp>(target)) {
1199  results.push_back(target);
1200  return DiagnosedSilenceableFailure::success();
1201  }
1202  rewriter.setInsertionPoint(target);
1203  FailureOr<LinalgOp> named =
1204  specializeGenericOp(rewriter, cast<GenericOp>(target));
1205  if (succeeded(named)) {
1206  results.push_back(named->getOperation());
1207  return DiagnosedSilenceableFailure::success();
1208  }
1209  return emitDefaultSilenceableFailure(target);
1210 }
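// GeneralizeOp and SpecializeOp are intended as inverses: the former rewrites
// a named structured op into an equivalent linalg.generic, while the latter
// raises a linalg.generic that matches a known named op back to that named
// form; both forward the target unchanged when no rewrite is needed.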
1211 
1212 //===----------------------------------------------------------------------===//
1213 // InterchangeOp
1214 //===----------------------------------------------------------------------===//
1215 
1216 DiagnosedSilenceableFailure
1217 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1218  GenericOp target,
1219  transform::ApplyToEachResultList &results,
1220  transform::TransformState &state) {
1221  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1222  // Exit early if no transformation is needed.
1223  if (interchangeVector.empty()) {
1224  results.push_back(target);
1225  return DiagnosedSilenceableFailure::success();
1226  }
1227 
1228  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1229  if (interchangeVector.size() != numLoops) {
1230  return emitSilenceableError()
1231  << getIteratorInterchangeAttrName() << " has length ("
1232  << interchangeVector.size()
1233  << ") different from the number of loops in the target operation ("
1234  << numLoops << ")";
1235  }
1236  FailureOr<GenericOp> res = interchangeGenericOp(
1237  rewriter, target, SmallVector<unsigned>(interchangeVector));
1238  if (failed(res))
1239  return emitDefiniteFailure() << "failed to apply";
1240  results.push_back(res->getOperation());
1241  return DiagnosedSilenceableFailure::success();
1242 }
1243 
1244 LogicalResult transform::InterchangeOp::verify() {
1245  ArrayRef<int64_t> permutation = getIteratorInterchange();
1246  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1247  if (!std::is_permutation(sequence.begin(), sequence.end(),
1248  permutation.begin(), permutation.end())) {
1249  return emitOpError()
1250  << "expects iterator_interchange to be a permutation, found "
1251  << getIteratorInterchange();
1252  }
1253  return success();
1254 }
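// Example of a valid attribute for the check above: for a generic with three
// loops, iterator_interchange = [2, 0, 1] is a permutation of [0, 1, 2] and is
// accepted, whereas [0, 0, 1] or a vector of the wrong length is rejected.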
1255 
1256 //===----------------------------------------------------------------------===//
1257 // LinalgCopyToMemrefOp
1258 //===----------------------------------------------------------------------===//
1259 
1260 DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1261  transform::TransformRewriter &rewriter, Operation *targetOp,
1262  transform::ApplyToEachResultList &results,
1263  transform::TransformState &state) {
1264 
1265  // Check if the target can be converted.
1266  if (!isa<linalg::CopyOp>(targetOp)) {
1267  DiagnosedSilenceableFailure diag =
1268  emitSilenceableError() << "only linalg.copy target ops are supported";
1269  diag.attachNote(targetOp->getLoc()) << "target op";
1270  return diag;
1271  }
1272 
1273  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1274  if (!copyOp.hasPureBufferSemantics()) {
1275  DiagnosedSilenceableFailure diag =
1276  emitSilenceableError()
1277  << "cannot transform a linalg.copy on tensors into a memref.copy";
1278  diag.attachNote(targetOp->getLoc()) << "target op";
1279  return diag;
1280  }
1281 
1282  SmallVector<Value> inputs = copyOp.getInputs();
1283  SmallVector<Value> outputs = copyOp.getOutputs();
1284  assert(inputs.size() == 1 && "expected linalg copy op with one input");
1285  assert(outputs.size() == 1 && "expected memref copy op with one output");
1286  Value input = inputs.front();
1287  Value output = outputs.front();
1288 
1289  // linalg.copy supports different element types on source/dest whereas
1290  // memref.copy does not, so we must check that the source and dest types can
1291  // be handled by memref.copy and otherwise reject the transformation.
1292  if (!isa<ShapedType>(input.getType())) {
1293  DiagnosedSilenceableFailure diag =
1294  emitSilenceableError()
1295  << "cannot transform a linalg.copy which input has no shape";
1296  diag.attachNote(targetOp->getLoc()) << "target op";
1297  return diag;
1298  }
1299 
1300  // linalg.copy destination must be a shaped type.
1301  assert(isa<ShapedType>(output.getType()));
1302 
1303  if (cast<ShapedType>(input.getType()).getElementType() !=
1304  cast<ShapedType>(output.getType()).getElementType()) {
1305  DiagnosedSilenceableFailure diag =
1306  emitSilenceableError()
1307  << "cannot transform a linalg.copy with different source and "
1308  "destination element types ";
1309  diag.attachNote(targetOp->getLoc()) << "target op";
1310  return diag;
1311  }
1312 
1313  // Target can be converted, do it.
1314  auto memrefCopyOp =
1315  rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1316 
1317  results.push_back(memrefCopyOp);
1318  return DiagnosedSilenceableFailure::success();
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // LowerPackOp
1323 //===----------------------------------------------------------------------===//
1324 
1325 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1326  transform::TransformRewriter &rewriter, linalg::PackOp target,
1327  transform::ApplyToEachResultList &transformResults,
1328  transform::TransformState &state) {
1329  rewriter.setInsertionPoint(target);
1330  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1331  FailureOr<LowerPackResult> res =
1332  lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1333  if (failed(res)) {
1334  return mlir::emitSilenceableFailure(target->getLoc())
1335  << "cannot lower to pad + expand + transpose";
1336  }
1337  transformResults.push_back(res->padOp);
1338  transformResults.push_back(res->expandShapeOp);
1339  transformResults.push_back(res->transposeOp);
1340  return DiagnosedSilenceableFailure::success();
1341 }
1342 
1343 //===----------------------------------------------------------------------===//
1344 // LowerUnPackOp
1345 //===----------------------------------------------------------------------===//
1346 
1347 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1348  transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1349  transform::ApplyToEachResultList &transformResults,
1350  transform::TransformState &state) {
1351  rewriter.setInsertionPoint(target);
1352  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1353  FailureOr<LowerUnPackOpResult> res =
1354  lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1355  if (failed(res)) {
1356  DiagnosedSilenceableFailure diag =
1357  emitSilenceableError()
1358  << "cannot lower to transpose + collapse + extract";
1359  diag.attachNote(target->getLoc()) << "target payload op";
1360  return diag;
1361  }
1362  transformResults.push_back(res->emptyOp);
1363  transformResults.push_back(res->transposeOp);
1364  transformResults.push_back(res->collapseShapeOp);
1365  transformResults.push_back(res->extractSliceOp);
1366  return DiagnosedSilenceableFailure::success();
1367 }
1368 
1369 //===---------------------------------------------------------------------===//
1370 // MatchOp
1371 //===---------------------------------------------------------------------===//
1372 
1373 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1374  Value target, ArrayRef<StringRef> opNames) {
1375  result.addOperands(target);
1376  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1377  builder.getStrArrayAttr(opNames));
1378  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1379 }
1380 
1381 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1382  TypeRange resultTypes, Value target,
1383  ArrayRef<StringRef> opNames) {
1384  result.addOperands(target);
1385  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1386  builder.getStrArrayAttr(opNames));
1387  result.addTypes(resultTypes);
1388 }
1389 
1390 DiagnosedSilenceableFailure
1391 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1392  transform::TransformResults &results,
1393  transform::TransformState &state) {
1394  llvm::StringSet<> strs;
1395  if (getOps().has_value())
1396  strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1397 
1398  auto payloadOps = state.getPayloadOps(getTarget());
1399  if (!llvm::hasSingleElement(payloadOps)) {
1400  return emitDefiniteFailure("requires exactly one target handle");
1401  }
1402 
1403  SmallVector<Operation *> res;
1404  bool incorrectNumOperandTypes = false;
1405  auto matchFun = [&](Operation *op) {
1406  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1407  return;
1408 
1409  // Interfaces cannot be matched by name, just by ID.
1410  // So we specifically encode the interfaces we care about for this op.
1411  if (getInterface().has_value()) {
1412  auto iface = getInterface().value();
1413  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1414  !isa<LinalgOp>(op))
1415  return;
1416  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1417  !isa<TilingInterface>(op))
1418  return;
1419  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1420  !isa<LoopLikeOpInterface>(op))
1421  return;
1422  }
1423 
1424  // Check if all specified attributes match.
1425  if (getOpAttrs().has_value()) {
1426  DictionaryAttr opAttrs = getOpAttrs().value();
1427  for (NamedAttribute attr : opAttrs) {
1428  if (attr.getName() == getInterfaceAttrName() ||
1429  attr.getName() == getOpsAttrName())
1430  continue;
1431  if (!op->hasAttr(attr.getName()))
1432  return;
1433  if (op->getAttr(attr.getName()) != attr.getValue())
1434  return;
1435  }
1436  }
1437 
1438  if (getFilterResultType().has_value()) {
1439  Type t = getFilterResultType().value();
1440  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1441  return;
1442  }
1443 
1444  if (getFilterOperandTypes().has_value()) {
1445  mlir::ArrayAttr types = getFilterOperandTypes().value();
1446  auto operandTypes = op->getOperandTypes();
1447 
1448  if (types.size() == 1) {
1449  // All the operands must be equal to the specified type
1450  auto typeattr =
1451  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1452  Type t = cast<::mlir::Type>(typeattr.getValue());
1453  if (!llvm::all_of(op->getOperandTypes(),
1454  [&](Type operandType) { return operandType == t; }))
1455  return;
1456  } else {
1457  // The operand types must match all the types in the list (in the same
1458  // order in which they are specified)
1459  if (types.size() != operandTypes.size()) {
1460  incorrectNumOperandTypes = true;
1461  return;
1462  }
1463 
1464  for (auto [attr, operandType] :
1465  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1466  auto typeattr = cast<mlir::TypeAttr>(attr);
1467  Type type = cast<::mlir::Type>(typeattr.getValue());
1468 
1469  if (type != operandType)
1470  return;
1471  }
1472  }
1473  }
1474 
1475  // All constraints are satisfied.
1476  res.push_back(op);
1477  return;
1478  };
1479 
1480  (*payloadOps.begin())->walk(matchFun);
1481  if (incorrectNumOperandTypes)
1482  return emitDefiniteFailure("If filter_operand_types contains more than a "
1483  "type, then it must contain as much types as "
1484  "the number of operands in the target ops");
1485  results.set(cast<OpResult>(getResult()), res);
1486  return DiagnosedSilenceableFailure::success();
1487 }
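// To summarize the walk above: an op is reported only if it passes every
// requested filter: op name (`ops`), interface kind, attribute dictionary,
// result type, and operand types; the target handle must also be mapped to
// exactly one payload root to walk.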
1488 
1489 //===---------------------------------------------------------------------===//
1490 // MultiTileSizesOp
1491 //===---------------------------------------------------------------------===//
1492 
1493 static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1494  Type targetType, Type lowSizeType, Type,
1495  Type) {
1496  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1497 }
1498 
1499 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1500  Type &targetType, Type &lowSizeType,
1501  Type &highSizeType,
1502  Type &splitPointType) {
1503  FunctionType funcType;
1504  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1505  if (failed(parser.parseType<FunctionType>(funcType)))
1506  return failure();
1507 
1508  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1509  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1510  "argument and one result";
1511  }
1512  targetType = funcType.getInput(0);
1513  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1514 
1515  return success();
1516 }
1517 
1518 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1519  transform::TransformRewriter &rewriter, LinalgOp target,
1520  ApplyToEachResultList &results, TransformState &state) {
1521  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1522  if (target.hasDynamicShape()) {
1523  auto diag = emitSilenceableError()
1524  << "cannot compute parametric tile sizes for dynamically "
1525  "shaped payload op";
1526  diag.attachNote(target->getLoc()) << "payload op";
1527  return diag;
1528  }
1529 
1530  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1531  target, getDimension(), getTargetSize(), getDivisor());
1532  if (failed(spec)) {
1533  return emitSilenceableError()
1534  << "failed to compute multi-size tiling sizes";
1535  }
1536 
1537  Builder builder(target.getContext());
1538  results.assign(llvm::map_range(
1539  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1540  spec->lowTileSize * spec->lowTripCount}),
1541  [&builder, this](int64_t value) {
1542  return builder.getIntegerAttr(
1543  cast<ParamType>(getLowSize().getType()).getType(), value);
1544  }));
1545  return DiagnosedSilenceableFailure::success();
1546  }
1547 
1548  OpBuilder builder(target.getContext());
1549  builder.setInsertionPoint(target);
1550  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1551  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1552  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1553  builder, target, getDimension(), targetSize, divisor);
1554  if (failed(spec)) {
1555  return emitSilenceableError() << "could not generate tile size computation";
1556  }
1557 
1558  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1559  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1560  Operation *splitPoint =
1561  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1562  {spec->lowTileSize, spec->lowTripCount});
1563  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1564  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1565  assert(lowTileSize && highTileSize && splitPoint &&
1566  "tile sizes are not produced by operations");
1567  results.reserve(results.size() + 3);
1568  results.push_back(lowTileSize);
1569  results.push_back(highTileSize);
1570  results.push_back(splitPoint);
1571  return DiagnosedSilenceableFailure::success();
1572 }
1573 
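// Illustrative example of the relation computed above (hypothetical numbers,
// not taken from the op documentation): if the specification yields
// lowTileSize = 2, lowTripCount = 4 and highTileSize = 3, the op returns the
// triple (2, 3, 8) -- the split point is lowTileSize * lowTripCount = 2 * 4,
// and the iteration space beyond the split point is expected to be covered by
// the high tile size. The non-parametric branch materializes the same product
// as an affine.apply on the symbolic expression s0 * s1.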
1574 void transform::MultiTileSizesOp::getEffects(
1575  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1576  onlyReadsHandle(getTargetMutable(), effects);
1577  producesHandle(getOperation()->getOpResults(), effects);
1578  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1579  onlyReadsPayload(effects);
1580  else
1581  modifiesPayload(effects);
1582 }
1583 
1584 LogicalResult transform::MultiTileSizesOp::verify() {
1585  if (getLowSize().getType() != getHighSize().getType() ||
1586  getLowSize().getType() != getSplitPoint().getType()) {
1587  return emitOpError() << "expects all results type to be the same";
1588  }
1589  return success();
1590 }
1591 
1592 //===---------------------------------------------------------------------===//
1593 // PackOp
1594 //===---------------------------------------------------------------------===//
1595 
1596 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1597  Value target,
1598  ArrayRef<OpFoldResult> mixedPackedSizes) {
1599  SmallVector<int64_t> staticPackedSizes;
1600  SmallVector<Value> dynamicPackedSizes;
1601  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1602  staticPackedSizes);
1603  // Call the default builder which sets up the proper operands segment sizes
1604  // attributes for multiple variadic operands. In the absence of this, horrible
1605  // bugs ensue.
1606  Type linalgOpHType = transform::OperationType::get(
1607  builder.getContext(), GenericOp::getOperationName());
1608  build(builder, result,
1609  /*resultType=*/linalgOpHType,
1610  /*target=*/target,
1611  /*dynamic_sizes=*/dynamicPackedSizes,
1612  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1613 }
1614 
1615 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1616  Builder b(getContext());
1617  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1618 }
1619 
1620 DiagnosedSilenceableFailure
1621 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1622  transform::TransformResults &transformResults,
1623  transform::TransformState &state) {
1624  auto targetOps = state.getPayloadOps(getTarget());
1625  // If nothing to pack, propagate success.
1626  if (std::empty(targetOps)) {
1627  transformResults.set(cast<OpResult>(getPackedOp()),
1628  ArrayRef<Operation *>({}));
1629  return DiagnosedSilenceableFailure::success();
1630  }
1631  // Fail on multi-op handles.
1632  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1633  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1634  return emitSilenceableError()
1635  << "requires target to map to exactly 1 LinalgOp (got "
1636  << llvm::range_size(targetOps) << ")";
1637  }
1638  // Fail on mismatched number of pack sizes.
1639  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1640  return emitSilenceableError()
1641  << "requires number of packed sizes to match the number of loops ("
1642  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1643  << ")";
1644  }
1645 
1646  // Unpack handles to constants or actual SSA index values.
1647  SmallVector<OpFoldResult> packedSizes;
1649  state, *this, packedSizes, getMixedPackedSizes());
1650 
1651  rewriter.setInsertionPoint(linalgOp);
1652  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1653  if (failed(maybeResult))
1654  return emitDefiniteFailure("data tiling failed");
1655 
1656  transformResults.set(cast<OpResult>(getPackedOp()),
1657  {maybeResult->packedLinalgOp.getOperation()});
1658  return DiagnosedSilenceableFailure::success();
1659 }
1660 
1661 void transform::PackOp::getEffects(
1662  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1663  transform::consumesHandle(getTargetMutable(), effects);
1664  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1665  transform::producesHandle(getOperation()->getOpResults(), effects);
1666  transform::modifiesPayload(effects);
1667 }
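// Illustrative use of this op from the transform dialect (a sketch; the
// authoritative assembly format lives in LinalgTransformOps.td):
//
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">
//
// The number of packed sizes must equal the number of loops of the target
// (checked above), and the result handle maps to the packed linalg op.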
1668 
1669 //===---------------------------------------------------------------------===//
1670 // PackGreedilyOp.
1671 //===---------------------------------------------------------------------===//
1672 
1673 LogicalResult transform::PackGreedilyOp::verify() {
1674  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1675  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1676  << " is not a valid permutation";
1677  }
1678  // TODO: relax to allow empty once we have another strategy than just matmul.
1679  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1680  for (auto [s, nmo] :
1681  llvm::zip_equal(getMixedMatmulPackedSizes(),
1682  getMatmulPaddedSizesNextMultipleOf())) {
1683  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1684  if (nmo != 0 &&
1685  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1686  return emitOpError() << "at most one of the packed_size and the "
1687  "padded_sizes_next_multiple_of can be nonzero "
1688  "for the matmul strategy";
1689  }
1690  }
1691  }
1692  return success();
1693 }
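// Illustrative attribute combinations for the check above (hypothetical
// values): matmul_packed_sizes = [8, 16, 0] together with
// matmul_padded_sizes_next_multiple_of = [0, 0, 4] verifies, because each of
// the m/n/k positions has at most one nonzero knob; specifying both a nonzero
// packed size and a nonzero next-multiple-of at the same position is rejected.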
1694 
1695 DiagnosedSilenceableFailure
1696 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1697  transform::TransformResults &transformResults,
1698  transform::TransformState &state) {
1699  SmallVector<Operation *> results;
1700  for (Operation *op : state.getPayloadOps(getTarget())) {
1701  auto linalgOp = dyn_cast<LinalgOp>(op);
1702  if (!linalgOp)
1703  continue;
1704  // linalgOp will be replaced and the insertion point may be invalidated if
1705  // we set it before -> set it after.
1706  rewriter.setInsertionPointAfter(linalgOp);
1707  // Failing to pack greedily is perfectly fine.
1708  // In the future we will want to order packings according to some metric.
1709  FailureOr<PackResult> packResult = packMatmulGreedily(
1710  /*rewriter=*/rewriter,
1711  /*linalgOp=*/linalgOp,
1712  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1713  /*mnkPaddedSizesNextMultipleOf=*/
1714  getMatmulPaddedSizesNextMultipleOf(),
1715  /*mnkOrder=*/getMatmulInnerDimsOrder());
1716  if (succeeded(packResult)) {
1717  results.push_back(packResult->packedLinalgOp);
1718  continue;
1719  }
1720  results.push_back(linalgOp);
1721  }
1722  transformResults.set(cast<OpResult>(getPackedOp()), results);
1723  return DiagnosedSilenceableFailure::success();
1724 }
1725 
1726 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1727  Builder b(getContext());
1728  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1729  b);
1730 }
1731 
1732 void transform::PackGreedilyOp::getEffects(
1733  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1734  transform::consumesHandle(getTargetMutable(), effects);
1735  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1736  transform::producesHandle(getOperation()->getOpResults(), effects);
1737  transform::modifiesPayload(effects);
1738 }
1739 
1740 //===---------------------------------------------------------------------===//
1741 // PackTransposeOp
1742 //===---------------------------------------------------------------------===//
1743 
1744 LogicalResult transform::PackTransposeOp::verify() {
1745  if (!isPermutationVector(getInnerPerm())) {
1746  return emitOpError() << getInnerPermAttrName()
1747  << " is not a valid permutation";
1748  }
1749  if (!isPermutationVector(getOuterPerm())) {
1750  return emitOpError() << getOuterPermAttrName()
1751  << " is not a valid permutation";
1752  }
1753  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1754  return emitOpError() << " at least one of " << getInnerPermAttrName()
1755  << " or " << getOuterPermAttrName()
1756  << " must be specified";
1757  }
1758  return success();
1759 }
1760 
1761 namespace {
1762 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1763 } // namespace
1764 
1765 /// Return true if `permutation` is a valid permutation of the
1766 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1767 /// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` op.
1768 /// This is the case when the `permutation` rank matches the rank expected by
1769 /// `op` and `permutation` is itself a permutation vector.
1770 /// Return true if either `op` or `permutation` are empty to allow a simpler
1771 /// polymorphic implementation.
1772 template <typename RelayoutOpTy>
1773 bool isValidPackingPermutation(
1774  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1775  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1776  static_assert(
1777  llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1778  "applies to only pack or unpack operations");
1779  if (!op || permutation.empty())
1780  return true;
1781  size_t innerRank = op.getInnerDimsPos().size();
1782  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1783  return permutation.size() == innerRank && isPermutationVector(permutation);
1784  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1785  // Don't rely on it.
1786  if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1787  return permutation.size() == op.getSourceRank() &&
1788  isPermutationVector(permutation);
1789  }
1790  return permutation.size() == op.getDestRank() &&
1791  isPermutationVector(permutation);
1792 }
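// For example (illustrative): given a linalg.pack whose inner_dims_pos has
// rank 2 and whose source has rank 4, an inner permutation must be a length-2
// permutation (e.g. [1, 0]) and an outer permutation a length-4 permutation
// (e.g. [3, 2, 1, 0]); for linalg.unpack the outer permutation is measured
// against the destination rank instead, as implemented above.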
1793 
1794 DiagnosedSilenceableFailure
1795 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1796  transform::TransformResults &transformResults,
1797  transform::TransformState &state) {
1798  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1799  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1800  // Step 1. If nothing to pack, propagate success.
1801  if (std::empty(packOrUnpackOps)) {
1802  transformResults.set(cast<OpResult>(getPackedOp()), {});
1803  transformResults.set(cast<OpResult>(getPackOp()), {});
1804  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1805  return DiagnosedSilenceableFailure::success();
1806  }
1807 
1808  // Step 2. Bunch of runtime sanity checks and error messages.
1809  // Step 2.1. Fail on multi-op handles.
1810  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1811  !llvm::hasSingleElement(linalgOps)) {
1812  return emitSilenceableError()
1813  << "requires target to map to exactly 1 "
1814  "packing op and 1 packed op ("
1815  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1816  << llvm::range_size(linalgOps) << ")";
1817  }
1818 
1819  // Step 2.2. Fail on wrong type.
1820  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1821  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1822  if ((!packOp && !unPackOp)) {
1823  return emitSilenceableError() << "requires target to map to a "
1824  "linalg.pack or linalg.unpack";
1825  }
1826  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1827  if (!linalgOpTarget)
1828  return emitSilenceableError() << "requires a LinalgOp target";
1829 
1830  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1831  LinalgOp linalgOp;
1832  if (packOp && packOp.getResult().hasOneUse())
1833  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1834  else if (unPackOp)
1835  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1836  if (linalgOp != linalgOpTarget) {
1837  auto errorMsg =
1838  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1839  : StringLiteral{"not produced by the LinalgOp target"};
1840  return emitSilenceableError() << errorMsg;
1841  }
1842 
1843  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1844  // PackOp.
1845  if (unPackOp) {
1846  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1847  OpOperand *packUse = linalgOp.getDpsInitOperand(
1848  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1849  packOp = packUse->get().getDefiningOp<linalg::PackOp>();
1850  if (!packOp || !packOp.getResult().hasOneUse())
1851  return emitSilenceableError() << "could not find matching pack op";
1852  }
1853 
1854  // Step 2.5. Fail if any permutation does not validate.
1855  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1856  ArrayRef<int64_t> perm =
1857  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1858  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1859  ? StringLiteral{"invalid outer_perm"}
1860  : StringLiteral{"invalid inner_perm"};
1861  if (!isValidPackingPermutation(packOp, perm, permType) ||
1862  !isValidPackingPermutation(unPackOp, perm, permType)) {
1863  Operation *packOrUnpackOp =
1864  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1865  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1866  }
1867  }
1868 
1869  // From here on, packOp and linalgOp are always present, unPackOp may or may
1870  // not be present.
1871  assert(packOp && linalgOp && "unexpected null op");
1872 
1873  // Step 3. Actually transpose the ops.
1874  FailureOr<PackTransposeResult> res = packTranspose(
1875  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1876  // Preconditions have been checked, it is an error to fail here.
1877  assert(succeeded(res) && "unexpected packTranspose failure");
1878 
1879  // Step 4. Return results.
1880  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1881  transformResults.set(cast<OpResult>(getPackedOp()),
1882  {res->transposedLinalgOp});
1883  if (unPackOp) {
1884  transformResults.set(cast<OpResult>(getUnPackOp()),
1885  {res->transposedUnPackOp});
1886  } else {
1887  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1888  }
1889 
1890  return DiagnosedSilenceableFailure::success();
1891 }
1892 
1893 //===---------------------------------------------------------------------===//
1894 // PadOp
1895 //===---------------------------------------------------------------------===//
1896 
1897 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1898  ArrayRef<int64_t> paddingDimensions,
1899  ArrayRef<int64_t> padToMultipleOf,
1900  ArrayRef<int64_t> nofoldFlags,
1901  ArrayRef<Attribute> transposePaddings,
1902  StringRef copyBackOp,
1903  bool usePrescribedTensorShapes) {
1904  auto resultType = transform::AnyOpType::get(b.getContext());
1905  return build(/*builder=*/b,
1906  /*result=*/result,
1907  /*types=*/TypeRange{resultType, resultType},
1908  /*target=*/target,
1909  /*paddingValues=*/ArrayAttr(), // let inference handle this
1910  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1911  /*padToMultipleOf=*/ValueRange{},
1912  /*padToMultipleOf=*/
1913  (padToMultipleOf.empty()
1914  ? DenseI64ArrayAttr()
1915  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1916  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1917  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1918  /*copyBackOp=*/b.getStringAttr(copyBackOp),
1919  /*usePrescribedTensorShapes=*/
1920  usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
1921 }
1922 
1923 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1924  ArrayRef<int64_t> paddingDimensions,
1925  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1926  ArrayRef<int64_t> nofoldFlags,
1927  ArrayRef<Attribute> transposePaddings,
1928  StringRef copyBackOp,
1929  bool usePrescribedTensorShapes) {
1930  auto resultType = transform::AnyOpType::get(b.getContext());
1931  SmallVector<int64_t> staticPadToMultipleOf;
1932  SmallVector<Value> dynamicPadToMultipleOf;
1933  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1934  staticPadToMultipleOf);
1935  return build(/*builder=*/b,
1936  /*result=*/result,
1937  /*types=*/TypeRange{resultType, resultType},
1938  /*target=*/target,
1939  /*paddingValues=*/ArrayAttr(), // let inference handle this
1940  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1941  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1942  /*padToMultipleOf=*/staticPadToMultipleOf,
1943  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1944  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1945  /*copyBackOp=*/copyBackOp,
1946  /*usePrescribedTensorShapes=*/usePrescribedTensorShapes);
1947 }
1948 
1949 void PadOp::getEffects(
1950  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1951  consumesHandle(getTargetMutable(), effects);
1952  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1953  producesHandle(getOperation()->getOpResults(), effects);
1954  modifiesPayload(effects);
1955 }
1956 
1957 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1958  Builder b(getContext());
1959  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1960 }
1961 
1962 DiagnosedSilenceableFailure
1963 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1964  transform::TransformResults &results,
1965  transform::TransformState &state) {
1966  auto transformOp = cast<TransformOpInterface>(getOperation());
1967  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1968 
1969  for (Operation *target : state.getPayloadOps(getTarget())) {
1970  auto linalgTarget = dyn_cast<LinalgOp>(target);
1971  if (!linalgTarget) {
1972  auto diag = emitSilenceableError() << "expected LinalgOp target";
1973  diag.attachNote(target->getLoc()) << "target op";
1974  return diag;
1975  }
1976 
1977  // Convert the integer packing flags to booleans.
1978  SmallVector<bool> nofoldFlags;
1979  for (int64_t packPadding :
1980  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1981  nofoldFlags.push_back(static_cast<bool>(packPadding));
1982 
1983  // Convert the padding values to attributes.
1984  SmallVector<Attribute> paddingValues;
1985  for (auto const &[untypedAttr, elementOrTensorType] :
1986  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1987 
1988  if (isa<ub::PoisonAttr>(untypedAttr)) {
1989  paddingValues.push_back(untypedAttr);
1990  continue;
1991  }
1992  auto attr = dyn_cast<TypedAttr>(untypedAttr);
1993  if (!attr) {
1994  emitOpError("expects padding values to be typed attributes or poison");
1996  }
1997  Type elementType = getElementTypeOrSelf(elementOrTensorType);
1998  // Try to parse string attributes to obtain an attribute of element type.
1999  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2000  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2001  stringAttr, getContext(), elementType,
2002  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2003  if (!parsedAttr || parsedAttr.getType() != elementType) {
2004  auto diag = this->emitOpError("expects a padding that parses to ")
2005  << elementType << ", got " << untypedAttr;
2006  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2008  }
2009  paddingValues.push_back(parsedAttr);
2010  continue;
2011  }
2012  // Otherwise, add the attribute directly.
2013  if (attr.getType() != elementType) {
2014  auto diag = this->emitOpError("expects a padding value of type ")
2015  << elementType << ", got " << attr;
2016  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2018  }
2019  paddingValues.push_back(attr);
2020  }
2021 
2022  // Extract the transpose vectors.
2023  SmallVector<SmallVector<int64_t>> transposePaddings;
2024  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2025  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2026  cast<ArrayAttr>(transposeVector)));
2027 
2028  LinalgOp paddedOp;
2029  LinalgPaddingOptions options;
2030  options.paddingDimensions =
2031  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2032 
2033  SmallVector<int64_t> padToMultipleOf;
2035  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2036  if (!status.succeeded())
2037  return status;
2038  if (padToMultipleOf.empty())
2039  padToMultipleOf =
2040  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2041 
2042  options.padToMultipleOf = padToMultipleOf;
2043  options.paddingValues = paddingValues;
2044  options.nofoldFlags = nofoldFlags;
2045  if (getCopyBackOp() ==
2046  bufferization::MaterializeInDestinationOp::getOperationName()) {
2049  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2051  } else if (getCopyBackOp() == kCopyOpNone) {
2053  } else {
2054  llvm_unreachable("unsupported copy_back op");
2055  }
2056  // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2057  bool irChanged = false;
2058  if (getUsePrescribedTensorShapes() &&
2059  linalgTarget.hasPureTensorSemantics()) {
2060  OpBuilder::InsertionGuard g(rewriter);
2061  rewriter.setInsertionPoint(linalgTarget);
2062  for (OpOperand &operand : linalgTarget->getOpOperands()) {
2063  for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2064  if (ShapedType::isStatic(dim))
2065  continue;
2066  options.setSizeToPadTo(operand.getOperandNumber(), i,
2067  tensor::getMixedSize(rewriter,
2068  operand.get().getLoc(),
2069  operand.get(), i));
2070  irChanged = true;
2071  }
2072  }
2073  }
2074 
2075  SmallVector<Value> replacements;
2076  SmallVector<tensor::PadOp> newPadOps;
2077  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2078  replacements, newPadOps))) {
2079  if (irChanged) {
2080  auto diag = emitDefiniteFailure() << "failed to pad op";
2081  diag.attachNote(target->getLoc()) << "target op";
2082  return diag;
2083  }
2084  auto diag = emitSilenceableError() << "failed to pad op";
2085  diag.attachNote(target->getLoc()) << "target op";
2086  return diag;
2087  }
2088 
2089  // We need to perform our own replacement here because this API is still
2090  // used in patterns that "pad and hoist", for which the replacement values
2091  // need to be different.
2092  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2093  // that we have more composable abstractions.
2094  rewriter.replaceOp(linalgTarget, replacements);
2095  paddedOps.push_back(paddedOp);
2096  padOps.append(newPadOps.begin(), newPadOps.end());
2097  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2098  for (Value v : replacements) {
2099  Operation *copyBackOp = v.getDefiningOp();
2100  if (!llvm::is_contained(copyBackOps, copyBackOp))
2101  copyBackOps.push_back(copyBackOp);
2102  }
2103  }
2104  }
2105 
2106  results.set(cast<OpResult>(getPadded()), paddedOps);
2107  results.set(cast<OpResult>(getPad()), padOps);
2108  results.set(cast<OpResult>(getCopy()), copyBackOps);
2109  return DiagnosedSilenceableFailure::success();
2110 }
2111 
2112 LogicalResult transform::PadOp::verify() {
2113  SmallVector<int64_t> nofoldFlags =
2114  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2115  if (any_of(nofoldFlags, [](int64_t packPadding) {
2116  return packPadding != 0 && packPadding != 1;
2117  })) {
2118  return emitOpError()
2119  << "expects nofold_flags to contain booleans (0/1), found "
2120  << getNofoldFlags();
2121  }
2122 
2123  SmallVector<int64_t> paddingDimensions =
2124  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2125  if (any_of(paddingDimensions,
2126  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2127  return emitOpError() << "expects padding_dimensions to contain non-negative "
2128  "integers, found "
2129  << getPaddingDimensions();
2130  }
2131  if (!getMixedPadToMultipleOf().empty()) {
2132  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2133  return emitOpError() << "expects as many multiples as padding_dimensions";
2134  }
2135  }
2136  ArrayAttr transposes = getTransposePaddings();
2137  for (Attribute attr : transposes) {
2138  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2139  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2140  if (!std::is_permutation(sequence.begin(), sequence.end(),
2141  transpose.begin(), transpose.end())) {
2142  return emitOpError()
2143  << "expects transpose_paddings to be a permutation, found "
2144  << attr;
2145  }
2146  }
2147  if (getCopyBackOp() !=
2148  bufferization::MaterializeInDestinationOp::getOperationName() &&
2149  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2150  getCopyBackOp() != kCopyOpNone)
2151  return emitOpError() << "invalid copy_back_op";
2152  return success();
2153 }
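// Illustrative use from the transform dialect (a sketch; the authoritative
// assembly format and attribute spellings live in LinalgTransformOps.td):
//
//   %padded, %pad, %copy = transform.structured.pad %matmul {
//       padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//       padding_dimensions = [0, 1, 2],
//       nofold_flags = [1, 1, 1],
//       copy_back_op = "linalg.copy"}
//       : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)
//
// Per the verifier above, nofold_flags must be 0/1 values, padding dimensions
// must be non-negative, and every transpose_paddings entry must be a
// permutation.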
2154 
2155 //===---------------------------------------------------------------------===//
2156 // PadTilingInterfaceOp
2157 //===---------------------------------------------------------------------===//
2158 
2159 void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2160  OperationState &result,
2161  Value target,
2162  ArrayRef<int64_t> paddingSizes,
2163  bool padToMultipleOf) {
2164  auto resultType = transform::AnyOpType::get(b.getContext());
2165  return build(/*builder=*/b,
2166  /*result=*/result,
2167  /*types=*/TypeRange{resultType, resultType},
2168  /*target=*/target,
2169  /*paddingValues=*/ArrayAttr(), // let inference handle this
2170  /*paddingSizes=*/ValueRange{},
2171  /*paddingSizes=*/
2172  (paddingSizes.empty() ? DenseI64ArrayAttr()
2173  : b.getDenseI64ArrayAttr(paddingSizes)),
2174  /*padToMultipleOf=*/
2175  padToMultipleOf ? b.getUnitAttr() : nullptr);
2176 }
2177 
2178 void transform::PadTilingInterfaceOp::build(
2179  OpBuilder &b, OperationState &result, Value target,
2180  ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2181  auto resultType = transform::AnyOpType::get(b.getContext());
2182  SmallVector<int64_t> staticPaddingSizes;
2183  SmallVector<Value> dynamicPaddingSizes;
2184  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2185  staticPaddingSizes);
2186  return build(/*builder=*/b,
2187  /*result=*/result,
2188  /*types=*/TypeRange{resultType, resultType},
2189  /*target=*/target,
2190  /*paddingValues=*/ArrayAttr(), // let inference handle this
2191  /*paddingSizes=*/dynamicPaddingSizes,
2192  /*paddingSizes=*/staticPaddingSizes,
2193  /*usePrescribedTensorShapes=*/padToMultipleOf);
2194 }
2195 
2196 void transform::PadTilingInterfaceOp::getEffects(
2197  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2198  consumesHandle(getTargetMutable(), effects);
2199  onlyReadsHandle(getPaddingSizesMutable(), effects);
2200  producesHandle(getOperation()->getOpResults(), effects);
2201  modifiesPayload(effects);
2202 }
2203 
2204 SmallVector<OpFoldResult>
2205 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2206  Builder b(getContext());
2207  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2208 }
2209 
2210 DiagnosedSilenceableFailure
2211 transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2212  transform::TransformResults &results,
2213  transform::TransformState &state) {
2214  SmallVector<Operation *> paddedOps, padOps;
2215 
2216  for (Operation *target : state.getPayloadOps(getTarget())) {
2217  auto targetOp = dyn_cast<TilingInterface>(target);
2218  if (!targetOp) {
2219  auto diag = emitSilenceableError() << "expected TilingInterface target";
2220  diag.attachNote(target->getLoc()) << "target op";
2221  return diag;
2222  }
2223 
2224  // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2225  // loopsToOperand map / C++ APIs to compute the effect of padding on
2226  // operands.
2227  if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2228  auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2229  "are supported at the moment";
2230  diag.attachNote(target->getLoc()) << "target op";
2231  return diag;
2232  }
2233 
2234  // Convert the padding values to attributes.
2235  SmallVector<Attribute> paddingValues;
2236  for (auto const &[untypedAttr, elementOrTensorType] :
2237  llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2238  auto attr = dyn_cast<TypedAttr>(untypedAttr);
2239  Type elementType = getElementTypeOrSelf(elementOrTensorType);
2240 
2241  if (isa<ub::PoisonAttr>(untypedAttr)) {
2242  paddingValues.push_back(untypedAttr);
2243  continue;
2244  }
2245  if (!attr) {
2246  emitOpError("expects padding values to be typed attributes or poison");
2248  }
2249  // Try to parse string attributes to obtain an attribute of element type.
2250  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2251  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2252  stringAttr, getContext(), elementType,
2253  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2254  if (!parsedAttr || parsedAttr.getType() != elementType) {
2255  auto diag = this->emitOpError("expects a padding that parses to ")
2256  << elementType << ", got " << attr;
2257  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2259  }
2260  paddingValues.push_back(parsedAttr);
2261  continue;
2262  }
2263  // Otherwise, add the attribute directly.
2264  if (attr.getType() != elementType) {
2265  auto diag = this->emitOpError("expects a padding value of type ")
2266  << elementType << ", got " << attr;
2267  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2269  }
2270  paddingValues.push_back(attr);
2271  }
2272 
2273  // Set options.
2274  TilingInterface paddedOp;
2275  PadTilingInterfaceOptions options;
2276  options.setPaddingValues(paddingValues)
2277  .setPaddingSizes(getMixedPaddingSizes())
2278  .setPadToMultipleOf(getPadToMultipleOf());
2279 
2280  // Apply padding.
2281  SmallVector<tensor::PadOp> newPadOps;
2282  FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
2283  rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
2284  newPadOps);
2285  if (failed(maybePaddedOp)) {
2286  auto diag = emitSilenceableError() << "failed to pad op";
2287  diag.attachNote(target->getLoc()) << "target op";
2288  return diag;
2289  }
2290 
2291  // Set transform results.
2292  paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2293  padOps.append(newPadOps.begin(), newPadOps.end());
2294  }
2295 
2296  results.set(cast<OpResult>(getPadded()), paddedOps);
2297  results.set(cast<OpResult>(getPad()), padOps);
2298  return DiagnosedSilenceableFailure::success();
2299 }
2300 
2301 LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2302 
2303 //===---------------------------------------------------------------------===//
2304 // HoistPadOp
2305 //===---------------------------------------------------------------------===//
2306 
2307 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2308  transform::TransformRewriter &rewriter,
2309  transform::TransformResults &transformResults,
2310  transform::TransformState &state) {
2311  auto targetOps = state.getPayloadOps(getTarget());
2312  auto loopOps = state.getPayloadOps(getLoop());
2313  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2314  return emitDefiniteFailure()
2315  << "requires exactly one target and one loop handle (got "
2316  << llvm::range_size(targetOps) << " and "
2317  << llvm::range_size(loopOps) << ")";
2318  }
2319 
2320  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2321  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2322  if (!padOp || !loopOp)
2323  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2324 
2325  FailureOr<linalg::detail::PackingResult> result =
2326  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2327  getTranspose());
2328  if (failed(result))
2329  return emitDefiniteFailure() << "could not build packing loop nest";
2330 
2331  if (result->clonedLoopIvs.empty()) {
2332  transformResults.set(cast<OpResult>(getPackingLoop()),
2333  {result->hoistedPadOp.getOperation()});
2334  return DiagnosedSilenceableFailure::success();
2335  }
2336  auto outerPackedLoop =
2337  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2338  transformResults.set(cast<OpResult>(getPackingLoop()),
2339  {outerPackedLoop.getOperation()});
2340  return DiagnosedSilenceableFailure::success();
2341 }
2342 
2343 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2344  ArrayRef<int64_t> transpose = getTranspose();
2345  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2346  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2347  transpose.end())) {
2348  return emitOpError() << "expects transpose to be a permutation, found "
2349  << getTranspose();
2350  }
2351  return success();
2352 }
2353 
2354 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2355  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2356  transform::onlyReadsHandle(getTargetMutable(), effects);
2357  transform::onlyReadsHandle(getLoopMutable(), effects);
2358  transform::producesHandle(getOperation()->getOpResults(), effects);
2359  transform::modifiesPayload(effects);
2360 }
2361 
2362 DiagnosedSilenceableFailure
2363 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2364  tensor::PadOp target,
2365  transform::ApplyToEachResultList &results,
2366  transform::TransformState &state) {
2367  tensor::PadOp hoistedPadOp;
2368  SmallVector<TransposeOp> transposeOps;
2369  FailureOr<Value> result =
2370  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2371  hoistedPadOp, transposeOps);
2372  if (succeeded(result)) {
2373  // We need to perform our own replacement here because this API is still
2374  // used in patterns that "pad and hoist", for which the replacement values
2375  // need to be different.
2376  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2377  // that we have more composable abstractions.
2378  rewriter.replaceOp(target, *result);
2379  results.push_back(hoistedPadOp);
2380  return DiagnosedSilenceableFailure::success();
2381  }
2382  return emitDefaultSilenceableFailure(target);
2383 }
2384 
2385 LogicalResult transform::HoistPadOp::verify() {
2386  ArrayRef<int64_t> transpose = getTranspose();
2387  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2388  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2389  transpose.end())) {
2390  return emitOpError() << "expects transpose to be a permutation, found "
2391  << getTranspose();
2392  }
2393  return success();
2394 }
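// For example (illustrative): transpose = [2, 0, 1] is accepted because it is
// a permutation of [0, 1, 2], whereas transpose = [0, 0, 2] is rejected by the
// std::is_permutation check above.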
2395 
2396 //===----------------------------------------------------------------------===//
2397 // PromoteOp
2398 //===----------------------------------------------------------------------===//
2399 
2400 DiagnosedSilenceableFailure
2401 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2402  LinalgOp target,
2403  transform::ApplyToEachResultList &results,
2404  transform::TransformState &state) {
2405  LinalgPromotionOptions promotionOptions;
2406  if (!getOperandsToPromote().empty())
2407  promotionOptions = promotionOptions.setOperandsToPromote(
2408  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2409  if (getUseFullTilesByDefault())
2410  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2411  getUseFullTilesByDefault());
2412  if (getUseOriginalSubviewSize())
2413  promotionOptions =
2414  promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
2415  if (getUseAlloca())
2416  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2417  if (!getUseFullTileBuffers().empty())
2418  promotionOptions = promotionOptions.setUseFullTileBuffers(
2419  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2420  if (getAlignment().has_value())
2421  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2422  if (getMemorySpace().has_value())
2423  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2424 
2425  if (getMapping().has_value()) {
2426  // The mapping should contain only one element.
2427  auto mapping = *getMapping();
2428  if (mapping.size() > 1)
2429  return emitDefaultDefiniteFailure(target);
2430 
2431  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2432 
2433  if (addressSpace.getAddressSpace() ==
2434  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2435  promotionOptions =
2436  promotionOptions
2440  .setUseFullTileBuffers({false, false});
2441  } else if (addressSpace.getAddressSpace() ==
2442  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2443  promotionOptions =
2444  promotionOptions
2448  .setUseFullTileBuffers({false, false});
2449  } else {
2450  return emitDefaultDefiniteFailure(target);
2451  }
2452  }
2453 
2454  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2455  return emitDefaultDefiniteFailure(target);
2456 
2457  rewriter.setInsertionPoint(target);
2458  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2459  if (failed(res))
2460  return emitDefaultDefiniteFailure(target);
2461  results.push_back(target);
2462  return DiagnosedSilenceableFailure::success();
2463 }
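// Illustrative use (a sketch; the authoritative assembly format and attribute
// spellings live in LinalgTransformOps.td):
//
//   %p = transform.structured.promote %tiled_matmul {
//       operands_to_promote = [0, 1], use_full_tiles_by_default }
//       : (!transform.any_op) -> !transform.any_op
//
// When a one-element `mapping` is provided, it must be a
// gpu::GPUMemorySpaceMappingAttr naming either the workgroup or the private
// address space; anything else fails definitively, as implemented above.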
2464 
2465 //===----------------------------------------------------------------------===//
2466 // ReplaceOp
2467 //===----------------------------------------------------------------------===//
2468 
2469 DiagnosedSilenceableFailure
2470 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2471  TransformResults &transformResults,
2472  TransformState &state) {
2473  auto payload = state.getPayloadOps(getTarget());
2474 
2475  // Check for invalid targets.
2476  for (Operation *target : payload) {
2477  if (target->getNumOperands() > 0)
2478  return emitDefiniteFailure() << "expected target without operands";
2479  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2480  target->getNumRegions() > 0)
2481  return emitDefiniteFailure()
2482  << "expected target that is isolated from above";
2483  }
2484 
2485  // Clone and replace.
2486  Operation *pattern = &getBodyRegion().front().front();
2487  SmallVector<Operation *> replacements;
2488  for (Operation *target : payload) {
2489  if (getOperation()->isAncestor(target))
2490  continue;
2491  rewriter.setInsertionPoint(target);
2492  Operation *replacement = rewriter.clone(*pattern);
2493  rewriter.replaceOp(target, replacement->getResults());
2494  replacements.push_back(replacement);
2495  }
2496  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2497  return DiagnosedSilenceableFailure::success();
2498 }
2499 
2500 void transform::ReplaceOp::getEffects(
2501  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2502  consumesHandle(getTargetMutable(), effects);
2503  producesHandle(getOperation()->getOpResults(), effects);
2504  modifiesPayload(effects);
2505 }
2506 
2507 LogicalResult transform::ReplaceOp::verify() {
2508  if (!getBodyRegion().hasOneBlock())
2509  return emitOpError() << "expected one block";
2510  if (std::distance(getBodyRegion().front().begin(),
2511  getBodyRegion().front().end()) != 1)
2512  return emitOpError() << "expected one operation in block";
2513  Operation *replacement = &getBodyRegion().front().front();
2514  if (replacement->getNumOperands() > 0)
2515  return replacement->emitOpError()
2516  << "expected replacement without operands";
2517  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2518  replacement->getNumRegions() > 0)
2519  return replacement->emitOpError()
2520  << "expect op that is isolated from above";
2521  return success();
2522 }
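// Illustrative use (a sketch; the authoritative assembly format lives in
// LinalgTransformOps.td; @dummy is a hypothetical name):
//
//   %new = transform.structured.replace %targets {
//     func.func private @dummy()
//   } : (!transform.any_op) -> !transform.any_op
//
// The single op in the body region is cloned in place of every payload op;
// per the verifier above it must have no operands and, unless it is isolated
// from above, no regions.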
2523 
2524 //===----------------------------------------------------------------------===//
2525 // ScalarizeOp
2526 //===----------------------------------------------------------------------===//
2527 
2528 DiagnosedSilenceableFailure
2529 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2530  LinalgOp target,
2531  transform::ApplyToEachResultList &results,
2532  transform::TransformState &state) {
2533  scf::SCFTilingOptions tilingOptions;
2534  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2535  SmallVector<OpFoldResult> tileSizes;
2536  Location loc = target.getLoc();
2537  SmallVector<OpFoldResult> allShapeSizes =
2538  target.createFlatListOfOperandDims(b, loc);
2539  AffineMap map = target.getShapesToLoopsMap();
2540  if (!map)
2541  return tileSizes;
2542  SmallVector<OpFoldResult> shapeSizes =
2544  allShapeSizes);
2545  // If the shape size is dynamic, tile by 1.
2546  // Otherwise, do not tile (i.e. tile size 0).
2547  for (OpFoldResult shapeSize : shapeSizes) {
2548  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2549  : b.getIndexAttr(1));
2550  }
2551  return tileSizes;
2552  });
2553  rewriter.setInsertionPoint(target);
2554  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2555  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2556  if (failed(maybeTilingResult))
2557  return emitDefaultDefiniteFailure(target);
2558 
2559  if (target->getNumResults())
2560  rewriter.replaceOp(target, maybeTilingResult->replacements);
2561  else
2562  rewriter.eraseOp(target);
2563 
2564  results.reserve(maybeTilingResult->tiledOps.size());
2565  for (Operation *tiled : maybeTilingResult->tiledOps)
2566  results.push_back(tiled);
2567  return DiagnosedSilenceableFailure::success();
2568 }
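// For example (illustrative): for a target whose loop ranges are (4, ?, 8),
// the tile-size callback above produces (0, 1, 0) -- static dimensions are
// left untiled (size 0) and only the dynamic dimension is tiled by 1, which
// is what reduces it to a unit extent.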
2569 
2570 //===----------------------------------------------------------------------===//
2571 // ConvertToLoopsOp
2572 //===----------------------------------------------------------------------===//
2573 
2574 DiagnosedSilenceableFailure
2575 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2576  transform::TransformResults &results,
2577  transform::TransformState &state) {
2578  SmallVector<Operation *> loops;
2579  for (Operation *target : state.getPayloadOps(getTarget())) {
2580  auto tilingOp = dyn_cast<TilingInterface>(*target);
2581  if (!tilingOp) {
2582  DiagnosedSilenceableFailure diag =
2583  emitSilenceableError()
2584  << "expected the payload to implement TilingInterface";
2585  diag.attachNote(target->getLoc()) << "payload op";
2586  return diag;
2587  }
2588  rewriter.setInsertionPoint(target);
2589  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2590  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2591  if (failed(generatedLoops))
2592  return emitDefaultDefiniteFailure(target);
2593  for (scf::ForOp &loop : *generatedLoops) {
2594  loops.push_back(loop.getOperation());
2595  }
2596  rewriter.eraseOp(target);
2597  }
2598  results.set(cast<OpResult>(getResult()), loops);
2600 }
2601 
2602 //===----------------------------------------------------------------------===//
2603 // RewriteInDestinationPassingStyleOp
2604 //===----------------------------------------------------------------------===//
2605 
2606 DiagnosedSilenceableFailure
2607 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2608  transform::TransformRewriter &rewriter, Operation *target,
2609  transform::ApplyToEachResultList &results,
2610  transform::TransformState &state) {
2611  rewriter.setInsertionPoint(target);
2612  FailureOr<Operation *> maybeResult =
2614  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2615  [&rewriter](auto op) {
2616  return rewriteInDestinationPassingStyle(rewriter, op);
2617  });
2618  if (failed(maybeResult))
2619  return emitDefaultSilenceableFailure(target);
2620  results.push_back(*maybeResult);
2622 }
2623 
2624 //===----------------------------------------------------------------------===//
2625 // SplitOp
2626 //===----------------------------------------------------------------------===//
2627 
2629 SplitOp::apply(transform::TransformRewriter &rewriter,
2630  TransformResults &results, TransformState &state) {
2631  // Collect the dynamic split points if provided.
2632  SmallVector<Operation *> payload =
2633  llvm::to_vector(state.getPayloadOps(getTarget()));
2634 
2635  bool isMultiwaySplit = getMultiway();
2636 
2637  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2638  return mlir::emitSilenceableFailure(getLoc())
2639  << "requires exactly one target when "
2640  "multiway split is enabled (got "
2641  << llvm::range_size(payload) << ")";
2642  }
2643 
2644  SmallVector<OpFoldResult> chunkSizes;
2645 
2646  if (!isMultiwaySplit)
2647  chunkSizes.reserve(payload.size());
2648 
2649  if (getDynamicChunkSizes()) {
2651  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2652  chunkSizes = llvm::to_vector(llvm::map_range(
2653  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2654  if (op->getNumResults() != 1 ||
2655  !op->getResult(0).getType().isIndex()) {
2656  diag = emitSilenceableError()
2657  << "expected dynamic split point handle to point to a "
2658  "single-result index-typed op";
2659  diag.attachNote(op->getLoc()) << "dynamic split point";
2660  }
2661  return OpFoldResult(op->getResult(0));
2662  }));
2663  } else {
2664  chunkSizes = llvm::to_vector(
2665  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2666  [](Attribute attr) { return OpFoldResult(attr); }));
2667  }
2668  if (diag.isSilenceableFailure())
2669  return diag;
2670 
2671  // For multiway split, a single payload is expected to have multiple
2672  // split points.
2673  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2674  return emitDefiniteFailure()
2675  << "expected the dynamic split point handle to point to as "
2676  "many operations ("
2677  << chunkSizes.size() << ") as the target handle ("
2678  << payload.size() << ")";
2679  }
2680  } else {
2681  chunkSizes.resize(payload.size(),
2682  rewriter.getIndexAttr(getStaticChunkSizes()));
2683  }
2684 
2685  auto checkStructuredOpAndDimensions =
2686  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2687  if (!linalgOp) {
2688  auto diag = emitSilenceableError() << "only applies to structured ops";
2689  diag.attachNote(loc) << "target op";
2690  return diag;
2691  }
2692 
2693  if (getDimension() >= linalgOp.getNumLoops()) {
2694  auto diag = emitSilenceableError() << "dimension " << getDimension()
2695  << " does not exist in target op";
2696  diag.attachNote(loc) << "target op";
2697  return diag;
2698  }
2700  };
2701 
2702  auto checkFailureInSplitting =
2703  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2704  if (hasFailed) {
2705  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2706  diag.attachNote(loc) << "target op";
2707  return diag;
2708  }
2710  };
2711 
2712  SmallVector<Operation *> opList;
2713  if (isMultiwaySplit) {
2714 
2715  // Split a single target operation at multiple points.
2716  TilingInterface head, tail;
2717  Operation *target = payload.front();
2718 
2719  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2720 
2721  // Check that the target is a valid LinalgOp with correct dimensions.
2722  DiagnosedSilenceableFailure diag =
2723  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2724  if (diag.isSilenceableFailure())
2725  return diag;
2726 
2727  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2728 
2729  if (idx > 0)
2730  target = tail.getOperation();
2731 
2732  if (!target)
2733  break;
2734 
2735  linalgOp = cast<LinalgOp>(target);
2736  Location loc = target->getLoc();
2737 
2738  rewriter.setInsertionPoint(linalgOp);
2739  std::tie(head, tail) = linalg::splitOp(
2740  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2741  getDimension(), chunkSize);
2742 
2743  // Propagate errors.
2744  DiagnosedSilenceableFailure diag =
2745  checkFailureInSplitting(!head && !tail, loc);
2746  if (diag.isDefiniteFailure())
2747  return diag;
2748 
2749  opList.push_back(head.getOperation());
2750  }
2751 
2752  // Append any leftover parts to the end of the result list.
2753  if (tail)
2754  opList.push_back(tail.getOperation());
2755 
2756  } else {
2757  // Split each target operation.
2758  SmallVector<Operation *> first, second;
2759  Operation *noSecondPart = nullptr;
2760  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2761  Operation *target = std::get<0>(pair);
2762  Location loc = target->getLoc();
2763  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2764  DiagnosedSilenceableFailure diag =
2765  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2766 
2767  if (diag.isSilenceableFailure())
2768  return diag;
2769 
2770  rewriter.setInsertionPoint(linalgOp);
2771  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2772  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2773  getDimension(), std::get<1>(pair));
2774 
2775  // Propagate errors.
2776  DiagnosedSilenceableFailure diagSplit =
2777  checkFailureInSplitting(!first.back() && !second.back(), loc);
2778  if (diagSplit.isDefiniteFailure())
2779  return diag;
2780 
2781  // Do not add null second parts.
2782  if (!second.back()) {
2783  noSecondPart = target;
2784  second.pop_back();
2785  }
2786  }
2787 
2788  if (second.size() != first.size() && !second.empty()) {
2789  auto diag = emitSilenceableError()
2790  << "splitting does not produce the second part for a subset "
2791  "of targets";
2792  diag.attachNote()
2793  << "expected splitting to produce the second part of all "
2794  "or none of the targets";
2795  diag.attachNote(noSecondPart->getLoc())
2796  << "first target with no second part";
2797  return diag;
2798  }
2799 
2800  opList.append(first);
2801  if (second.size())
2802  opList.append(second);
2803  }
2804  results.set(cast<OpResult>(getSplitList()), opList);
2806 }
2807 
2808 void SplitOp::getEffects(
2809  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2810  consumesHandle(getTargetMutable(), effects);
2811  if (getDynamicChunkSizes())
2812  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2813  producesHandle(getOperation()->getOpResults(), effects);
2814  modifiesPayload(effects);
2815 }
2816 
2817 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2818  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2819  IntegerAttr staticChunkSizes;
2820  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2821  return failure();
2822 
2823  OptionalParseResult dynamicPointParseResult =
2824  parser.parseOptionalOperand(dynamicChunkSizes);
2825  if (!dynamicPointParseResult.has_value()) {
2826  int64_t staticChunkSizesValue;
2827  if (failed(parser.parseInteger(staticChunkSizesValue)))
2828  return failure();
2829 
2830  staticChunkSizes =
2831  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2832  }
2833 
2834  Type targetType;
2835  if (parser.parseOptionalAttrDict(result.attributes) ||
2836  parser.parseColonType(targetType) ||
2837  parser.resolveOperand(target, targetType, result.operands)) {
2838  return failure();
2839  }
2840  if (dynamicPointParseResult.has_value()) {
2841  Type ChunkSizesType;
2842  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2843  parser.parseType(ChunkSizesType) ||
2844  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2845  result.operands)) {
2846  return failure();
2847  }
2848 
2849  staticChunkSizes =
2850  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2851  }
2852 
2853  result.addAttribute(
2854  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2855  staticChunkSizes);
2856  result.addTypes(targetType);
2857  return success();
2858 }
2859 
2860 void SplitOp::print(OpAsmPrinter &printer) {
2861  printer << " " << getTarget() << " after ";
2862  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2863  if (staticChunkSize != ShapedType::kDynamic)
2864  printer << staticChunkSize;
2865  else
2866  printer << getDynamicChunkSizes();
2867  printer << " ";
2868  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2869  {getStaticChunkSizesAttrName()});
2870  printer << " : " << getTarget().getType();
2871  if (staticChunkSize == ShapedType::kDynamic)
2872  printer << ", " << getDynamicChunkSizes().getType();
2873 }
2874 
2875 LogicalResult SplitOp::verify() {
2876  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2877  (getDynamicChunkSizes() == nullptr)) {
2878  return emitOpError() << "expects either a dynamic or a static split "
2879  "point to be provided";
2880  }
2881  return success();
2882 }
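// Illustrative assembly accepted by the parser/printer above (a sketch; the
// dynamic form's second type depends on the chunk-size handle that is used):
//
//   // Static chunk size.
//   %splits = transform.structured.split %target after 16 {dimension = 0}
//       : !transform.any_op
//   // Dynamic chunk size carried by a second handle.
//   %splits2 = transform.structured.split %target after %sz {dimension = 0}
//       : !transform.any_op, !transform.param<i64>
//
// Exactly one of the two forms is allowed, which is what the verifier's
// XOR-style check enforces.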
2883 
2884 //===----------------------------------------------------------------------===//
2885 // SplitReductionOp
2886 //===----------------------------------------------------------------------===//
2887 
2888 void transform::SplitReductionOp::build(
2889  OpBuilder &builder, OperationState &result, Value target,
2890  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2891  bool useScalingAlgorithm, bool useAlloc) {
2892  MLIRContext *ctx = builder.getContext();
2893  result.addOperands(target);
2894  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2895  builder.getI64IntegerAttr(splitFactor));
2896  result.addAttribute(
2897  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2898  builder.getI64IntegerAttr(insertSplitDimension));
2899  if (innerParallel) {
2900  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2901  builder.getUnitAttr());
2902  }
2903  if (useScalingAlgorithm) {
2904  result.addAttribute(
2905  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2906  builder.getUnitAttr());
2907  }
2908  if (useAlloc) {
2909  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2910  builder.getUnitAttr());
2911  }
2912  auto resultType = transform::AnyOpType::get(ctx);
2913  result.addTypes({resultType, resultType, resultType, resultType});
2914 }
2915 
2916 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2917  transform::TransformRewriter &rewriter, LinalgOp target,
2918  transform::ApplyToEachResultList &results,
2919  transform::TransformState &state) {
2920  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2921  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2922  unsigned(getInsertSplitDimension()),
2923  bool(getInnerParallel())};
2924  };
2925  rewriter.setInsertionPoint(target);
2926  FailureOr<SplitReductionResult> splitResult =
2927  (getUseScalingAlgorithm())
2928  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2929  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2930  if (failed(splitResult))
2931  return emitDefaultDefiniteFailure(target);
2932 
2933  results.push_back(splitResult->initOrAlloc);
2934  results.push_back(splitResult->fillOp);
2935  results.push_back(splitResult->splitLinalgOp);
2936  results.push_back(splitResult->resultCombiningLinalgOp);
2937  return DiagnosedSilenceableFailure::success();
2938 }
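// Illustrative use (a sketch; the authoritative assembly format lives in
// LinalgTransformOps.td):
//
//   %init, %fill, %split, %combine =
//       transform.structured.split_reduction %reduction
//       { split_factor = 4, insert_split_dimension = 2 }
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
//                                 !transform.any_op, !transform.any_op)
//
// The four results correspond to the init-or-alloc op, the fill op, the split
// reduction op and the combining op pushed back above, in that order.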
2939 
2940 //===----------------------------------------------------------------------===//
2941 // TileReductionUsingForOp
2942 //===----------------------------------------------------------------------===//
2943 
2944 void transform::TileReductionUsingForOp::build(
2945  OpBuilder &builder, OperationState &result, Value target,
2946  ArrayRef<int64_t> staticTileSizes) {
2947  // Call the default builder.
2948  // This is future-proof re mixed static-dynamic and setting up the proper
2949  // operands segment sizes attributes for multiple variadic operands.
2950  // In the absence of this, horrible bugs ensue.
2951  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2952  MLIRContext *ctx = builder.getContext();
2953  auto opTy = transform::AnyOpType::get(ctx);
2954  auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
2955  build(builder, result,
2956  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2957  /*target=*/target,
2958  /*reduction_dims=*/nullptr,
2959  /*tile_sizes=*/staticTileSizesAttr);
2960 }
2961 
2962 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2963  transform::TransformRewriter &rewriter, Operation *target,
2964  transform::ApplyToEachResultList &results,
2965  transform::TransformState &state) {
2966  rewriter.setInsertionPoint(target);
2967 
2968  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2969  if (!partialReductionOp) {
2970  return emitSilenceableFailure(
2971  target->getLoc(),
2972  "Operation should implement PartialReductionOpInterface");
2973  }
2974 
2975  SmallVector<unsigned> reductionDims =
2976  extractFromIntegerArrayAttr<unsigned>(getReductionDims());
2977  if (reductionDims.empty()) {
2978  for (auto [idx, iteratorType] :
2979  llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
2980  if (iteratorType == utils::IteratorType::reduction)
2981  reductionDims.push_back(idx);
2982  }
2983  }
2984 
2987  options.setReductionTilingStrategy(
2989  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
2990  options.setReductionDims(reductionDims);
2991  FailureOr<scf::SCFTilingResult> result =
2992  scf::tileUsingSCF(rewriter, partialReductionOp, options);
2993 
2994  if (failed(result)) {
2995  return emitSilenceableFailure(getLoc(),
2996  "failed to tile using partial reduction");
2997  }
2998  rewriter.replaceOp(target, result->replacements);
2999  for (Value initValue : result->initialValues)
3000  results.push_back(initValue.getDefiningOp());
3001  for (auto parallelTiledOp : result->tiledOps)
3002  results.push_back(parallelTiledOp);
3003  for (auto mergeOp : result->mergeOps)
3004  results.push_back(mergeOp);
3005  results.push_back(result->loops.front());
3006  return DiagnosedSilenceableFailure::success();
3007 }
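// Note on result ordering (derived from the code above): the op returns the
// ops defining the initial values, then the partially tiled ops, then the
// merge ops, and finally the outermost generated loop. For example, a
// linalg.matmul with its single reduction dimension tiled this way would
// typically yield one fill-like init op, one tiled op, one merge op and one
// scf.for loop.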
3008 
3009 //===----------------------------------------------------------------------===//
3010 // TileReductionUsingForallOp
3011 //===----------------------------------------------------------------------===//
3012 
3013 void transform::TileReductionUsingForallOp::build(
3014  OpBuilder &builder, OperationState &result, Value target,
3015  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3016  ArrayAttr mapping) {
3017  // Call the default builder.
3018  // This is future-proof regarding mixed static-dynamic sizes and sets up the proper
3019  // operands segment sizes attributes for multiple variadic operands.
3020  // In the absence of this, horrible bugs ensue.
3021  // TODO: support mixed static-dynamic (see TileUsingForallOp).
3022  MLIRContext *ctx = builder.getContext();
3023  auto opTy = transform::AnyOpType::get(ctx);
3024  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3025  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3026  build(builder, result,
3027  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3028  /*target=*/target,
3029  /*reduction_dims=*/{},
3030  /*num_threads=*/staticNumThreadsAttr,
3031  /*tile_sizes=*/staticTileSizesAttr,
3032  /*mapping=*/mapping);
3033 }
3034 
3035 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3036  transform::TransformRewriter &rewriter, LinalgOp target,
3038  transform::TransformState &state) {
3039  rewriter.setInsertionPoint(target);
3040  SmallVector<OpFoldResult> numThreads =
3041  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
3042  SmallVector<OpFoldResult> tileSizes =
3044 
3047  options.setReductionTilingStrategy(
3049  if (!getNumThreads().empty()) {
3050  options.setNumThreads(numThreads);
3051  } else {
3052  options.setTileSizes(tileSizes);
3053  }
3054  if (auto mapping = getMapping()) {
3055  options.setMapping(mapping.value().getValue());
3056  }
3057  SmallVector<unsigned> reductionDims =
3058  extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3059  if (reductionDims.empty()) {
3060  for (auto [idx, iteratorType] :
3061  llvm::enumerate(target.getIteratorTypesArray())) {
3062  if (iteratorType == utils::IteratorType::reduction)
3063  reductionDims.push_back(idx);
3064  }
3065  }
3066  options.setReductionDims(reductionDims);
3067  FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
3068  rewriter, cast<TilingInterface>(target.getOperation()), options);
3069 
3070  if (failed(result)) {
3071  auto diag = emitSilenceableError() << "could not tile reduction";
3072  return diag;
3073  }
3074  rewriter.replaceOp(target, result->replacements);
3075 
3076  for (Value initValue : result->initialValues)
3077  results.push_back(initValue.getDefiningOp());
3078  for (auto parallelTiledOp : result->tiledOps)
3079  results.push_back(parallelTiledOp);
3080  for (auto mergeOp : result->mergeOps)
3081  results.push_back(mergeOp);
3082  results.push_back(result->loops.front());
3083  return DiagnosedSilenceableFailure::success();
3084 }
3085 
3086 //===----------------------------------------------------------------------===//
3087 // ContinuousTileSizesOp
3088 //===----------------------------------------------------------------------===//
3089 
3091 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3092  TransformResults &transformResults,
3093  TransformState &state) {
3094 
3095  SmallVector<Operation *> targetOps =
3096  llvm::to_vector(state.getPayloadOps(getTarget()));
3097 
3098  if (!llvm::hasSingleElement(targetOps)) {
3099  return mlir::emitSilenceableFailure(getLoc())
3100  << "requires exactly one target (got " << llvm::range_size(targetOps)
3101  << ")";
3102  }
3103 
3104  Operation *target = *targetOps.begin();
3105  auto linalgOp = dyn_cast<LinalgOp>(target);
3106  auto tileableOp = dyn_cast<TilingInterface>(target);
3107 
3108  if (!linalgOp)
3109  return emitDefiniteFailure() << "expected Linalg Op";
3110 
3111  OpBuilder builder(linalgOp.getContext());
3112 
3113  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
3114  if (linalgOp.hasDynamicShape()) {
3115  auto diag = emitSilenceableError()
3116  << "cannot compute parametric tile sizes for dynamically "
3117  "shaped payload op";
3118  diag.attachNote(linalgOp->getLoc()) << "payload op";
3119  return diag;
3120  }
3121 
3122  FailureOr<StaticContinuousTileSizeSpecification> spec =
3123  computeStaticContinuousTileSizes(linalgOp, getDimension(),
3124  getTargetSize());
3125  if (failed(spec)) {
3126  return emitSilenceableError()
3127  << "failed to compute multi-size tiling sizes";
3128  }
3129 
3130  SmallVector<int64_t> chunkSizes;
3131 
3132  for (auto &&[tileSize, tripCount] :
3133  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3134  chunkSizes.push_back(tileSize * tripCount);
3135 
3136  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3137  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3138  return builder.getI64IntegerAttr(value);
3139  });
3140  };
3141  transformResults.setParams(cast<OpResult>(getTileSizes()),
3142  getI64AttrsFromI64(spec->tileSizes));
3143  transformResults.setParams(cast<OpResult>(getChunkSizes()),
3144  getI64AttrsFromI64(chunkSizes));
3145 
3146    return DiagnosedSilenceableFailure::success();
3147  }
3148 
3149  builder.setInsertionPoint(linalgOp);
3150 
3151  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
3152  unsigned dimension = getDimension();
3153 
3154  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
3155  builder, tileableOp, dimension, targetSize, true);
3156  if (failed(spec)) {
3157  return emitSilenceableError() << "could not generate tile size computation";
3158  }
3159 
3160  AffineExpr s0 = builder.getAffineSymbolExpr(0);
3161  AffineExpr s1 = builder.getAffineSymbolExpr(1);
3162  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3163  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
3164  ofrs);
3165  };
3166 
3167  SmallVector<Value> chunkSizes;
3168  Value splitPoint;
3169  for (auto &&[tileSize, tripCount] :
3170  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3171  splitPoint = apply(s0 * s1, {tileSize, tripCount});
3172  chunkSizes.push_back(splitPoint);
3173  }
3174 
3175  auto getDefiningOps = [&](ArrayRef<Value> values) {
3176  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3177  return value.getDefiningOp();
3178  });
3179  };
3180 
3181  transformResults.set(cast<OpResult>(getTileSizes()),
3182  getDefiningOps(spec->tileSizes));
3183  transformResults.set(cast<OpResult>(getChunkSizes()),
3184  getDefiningOps(chunkSizes));
3185 
3186  return DiagnosedSilenceableFailure::success();
3187 }
3188 
3189 LogicalResult transform::ContinuousTileSizesOp::verify() {
3190 
3191  if (getTileSizes().getType() != getChunkSizes().getType()) {
3192  return emitOpError() << "expects all results type to be the same";
3193  }
3194 
3195  return success();
3196 }
3197 
3198 void transform::ContinuousTileSizesOp::getEffects(
3200  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3201  onlyReadsPayload(effects);
3202  else
3203  modifiesPayload(effects);
3204  onlyReadsHandle(getTargetMutable(), effects);
3205  producesHandle(getOperation()->getOpResults(), effects);
3206 }
3207 
3209  Type targetType, Type tile_sizes,
3210  Type) {
3211  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
3212 }
3213 
3214 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
3215  Type &targetType,
3216  Type &tileSizesType,
3217  Type &chunkSizesType) {
3218  FunctionType funcType;
3219  llvm::SMLoc typeLoc = parser.getCurrentLocation();
3220  if (failed(parser.parseType<FunctionType>(funcType)))
3221  return failure();
3222 
3223  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3224  parser.emitError(typeLoc) << "expects a trailing functional type with one "
3225  "argument and one result";
3226  }
3227  targetType = funcType.getInput(0);
3228  tileSizesType = chunkSizesType = funcType.getResult(0);
3229 
3230  return success();
3231 }
3232 
3233 //===----------------------------------------------------------------------===//
3234 // TileUsingForOp
3235 //===----------------------------------------------------------------------===//
3236 
3237 void transform::TileUsingForOp::build(
3238  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3239  Value target, ArrayRef<int64_t> staticTileSizes,
3240  ArrayRef<int64_t> interchange,
3241  std::optional<ArrayRef<bool>> scalableSizes) {
3242  return build(builder, result, loopTypes,
3243  /*target=*/target,
3244  /*mixedTileSizes=*/
3245  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3246  interchange, scalableSizes);
3247 }
3248 
3249 void transform::TileUsingForOp::build(
3250  OpBuilder &builder, OperationState &result, Value target,
3251  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3252  std::optional<ArrayRef<bool>> scalableSizes) {
3253  build(builder, result, target,
3254  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3255  interchange, scalableSizes);
3256 }
3257 
3258 void transform::TileUsingForOp::build(
3259  OpBuilder &builder, OperationState &result, Value target,
3260  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3261  std::optional<ArrayRef<bool>> scalableSizes) {
3262  // Loop types are automatically splat by the callee, setting up one is
3263  // enough.
3264  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3265  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3266  scalableSizes);
3267 }
3268 
3269 void transform::TileUsingForOp::build(
3270  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3271  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3272  ArrayRef<int64_t> interchange,
3273  std::optional<ArrayRef<bool>> scalableSizes) {
3274  SmallVector<int64_t> staticTileSizes;
3275  SmallVector<Value> dynamicTileSizes;
3276  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3277  // Call the default builder which sets up the proper operands segment sizes
3278  // attributes for multiple variadic operands. In the absence of this,
3279  // horrible bugs ensue.
3280  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3281  unsigned numExpectedLoops =
3282  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
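  // Example: static tile sizes [4, 0, 8] produce 2 loops; a zero size means
  // the corresponding dimension is not tiled and yields no loop result.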
3283  SmallVector<Type> resultTypes;
3284  resultTypes.reserve(numExpectedLoops);
3285  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3286  "expected one loop type or as many as loops");
3287  if (loopTypes.size() == 1)
3288  resultTypes.append(numExpectedLoops, loopTypes[0]);
3289  else
3290  llvm::append_range(resultTypes, loopTypes);
3291  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3292  if (scalableSizes.has_value())
3293  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3294  build(builder, result, /*tiled_linalg_op=*/target.getType(),
3295  /*loops=*/resultTypes,
3296  /*target=*/target,
3297  /*dynamic_sizes=*/dynamicTileSizes,
3298  /*static_sizes=*/staticTileSizesAttr,
3299  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3300  /*scalable_sizes=*/expandedScalableSizes);
3301 }
3302 
3303 LogicalResult transform::TileUsingForOp::verify() {
3304  if (getMixedSizes().size() != getScalableSizes().size())
3305  return emitOpError("expected same number of sizes (")
3306  << getMixedSizes().size() << ") and scalable sizes ("
3307  << getScalableSizes().size() << ")";
3308  ArrayRef<int64_t> staticSizes = getStaticSizes();
3309  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3310  if (getLoops().size() != numExpectedLoops)
3311  return emitOpError("expected number of loops to tile (")
3312  << numExpectedLoops << ") to match number of `loops` results ("
3313  << getLoops().size() << ")";
3314  return success();
3315 }
3316 
3318 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3319  TransformResults &transformResults,
3320  TransformState &state) {
3321  ArrayRef<int64_t> tileSizes = getStaticSizes();
3322 
3323  SmallVector<Operation *> targets =
3324  llvm::to_vector(state.getPayloadOps(getTarget()));
3325  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3327  dynamicSizeProducers.reserve(getDynamicSizes().size());
3328  paramSizes.reserve(getDynamicSizes().size());
3329  for (Value transformValue : getDynamicSizes()) {
3330  if (isa<ParamType>(transformValue.getType())) {
3331  dynamicSizeProducers.push_back({});
3332  ArrayRef<Attribute> params = state.getParams(transformValue);
3333  paramSizes.push_back(
3334  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3335  return cast<IntegerAttr>(attr).getValue().getSExtValue();
3336  })));
3337 
3338  if (paramSizes.back().size() != targets.size()) {
3340  emitSilenceableError()
3341  << "expected as many parameter values ("
3342  << dynamicSizeProducers.back().size() << ") as target ops ("
3343  << targets.size() << ")";
3344  diag.attachNote(transformValue.getLoc()) << "for this parameter";
3345  return diag;
3346  }
3347 
3348  continue;
3349  }
3350  paramSizes.push_back({});
3351  dynamicSizeProducers.push_back(
3352  llvm::to_vector(state.getPayloadOps(transformValue)));
3353 
3354  if (dynamicSizeProducers.back().size() != targets.size()) {
3356  emitSilenceableError()
3357  << "expected as many dynamic size-producing operations ("
3358  << dynamicSizeProducers.back().size() << ") as target ops ("
3359  << targets.size() << ")";
3360  diag.attachNote(transformValue.getLoc()) << "for this handle";
3361  return diag;
3362  }
3363 
3364  for (Operation *op : dynamicSizeProducers.back()) {
3365  if (op->getNumResults() == 1 &&
3366  isa<IndexType>(op->getResult(0).getType())) {
3367  continue;
3368  }
3369 
3371  emitSilenceableError() << "expected sizes to be produced by ops "
3372  "with a single index-type result";
3373  diag.attachNote(op->getLoc()) << "size producer op";
3374  diag.attachNote(transformValue.getLoc()) << "for this handle";
3375  return diag;
3376  }
3377  }
3378 
3381  loops.resize(getLoops().size());
3382  auto scalableSizes = getScalableSizes();
3383  for (auto [i, op] : llvm::enumerate(targets)) {
3384  auto tilingInterface = dyn_cast<TilingInterface>(op);
3385  if (!tilingInterface) {
3387  emitSilenceableError()
3388  << "only ops implementing TilingInterface are supported";
3389  diag.attachNote(op->getLoc()) << "target op";
3390  return diag;
3391  }
3392  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3394  emitSilenceableError()
3395  << "too many tiles provided, expected at most "
3396  << tilingInterface.getLoopIteratorTypes().size() << " found "
3397  << tileSizes.size();
3398  diag.attachNote(op->getLoc()) << "target op";
3399  return diag;
3400  }
3401 
3402  scf::SCFTilingOptions tilingOptions;
3403  if (tileSizes.empty()) {
3404  tilingOptions.setTileSizeComputationFunction(
3406  return {};
3407  });
3408  } else {
3409  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3410  Operation *) {
3412  sizes.reserve(tileSizes.size());
3413  unsigned dynamicIdx = 0;
3414 
3415  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3416  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3417  if (scalableSizes[ofrIdx]) {
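  // A scalable static size `s` is materialized as `s * vector.vscale`,
  // so the effective tile size tracks the runtime vector length.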
3418  auto val = arith::ConstantIndexOp::create(
3419  b, getLoc(), cast<IntegerAttr>(attr).getInt());
3420  Value vscale =
3421  vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
3422  sizes.push_back(
3423  arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3424  } else {
3425  sizes.push_back(attr);
3426  }
3427  continue;
3428  }
3429  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3430  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3431  ++dynamicIdx;
3432  assert((dynamicSizes.empty() ^ params.empty()) &&
3433  "expected either dynamic sizes or parameters");
3434  if (!params.empty()) {
3435  sizes.push_back(b.getIndexAttr(params[index]));
3436  } else {
3437  sizes.push_back(dynamicSizes[index]->getResult(0));
3438  }
3439  }
3440  return sizes;
3441  });
3442  }
3443 
3444  tilingOptions.setInterchange(getInterchange());
3445  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3446  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3447  if (failed(maybeTilingResult))
3449 
3450  rewriter.replaceOp(op, maybeTilingResult->replacements);
3451 
3452  tiled.append(maybeTilingResult->tiledOps);
3453  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3454  loops[en2.index()].push_back(en2.value());
3455  }
3456 
3457  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3458  for (const auto &en : llvm::enumerate(loops))
3459  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3460 
3461  return DiagnosedSilenceableFailure::success();
3462 }
3463 
3465  ValueRange dynamic = getDynamicSizes();
3466  ArrayRef<int64_t> tileSizes = getStaticSizes();
3467  SmallVector<OpFoldResult> results;
3468  results.reserve(tileSizes.size());
3469  unsigned dynamicPos = 0;
3470  Builder builder(getContext());
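  // e.g. static sizes [4, kDynamic, 8] combined with a single dynamic size
  // operand %s yield the mixed list [4, %s, 8].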
3471  for (int64_t size : tileSizes) {
3472  if (size == ShapedType::kDynamic) {
3473  results.push_back(dynamic[dynamicPos++]);
3474  } else {
3475  results.push_back(builder.getIndexAttr(size));
3476  }
3477  }
3478  return results;
3479 }
3480 
3481 void transform::TileUsingForOp::getEffects(
3483  consumesHandle(getTargetMutable(), effects);
3484  onlyReadsHandle(getDynamicSizesMutable(), effects);
3485  producesHandle(getOperation()->getOpResults(), effects);
3486  modifiesPayload(effects);
3487 }
3488 
3489 //===----------------------------------------------------------------------===//
3490 // TileUsingForallOp
3491 //===----------------------------------------------------------------------===//
3492 
3493 void transform::TileUsingForallOp::build(OpBuilder &builder,
3494  OperationState &result, Value target,
3495  ArrayRef<int64_t> staticTileSizes,
3497  ArrayAttr mapping) {
3498  return build(builder, result,
3499  /*target=*/target,
3500  /*mixedTileSizes=*/
3501  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3502  /*_=*/TileSizesSpec(),
3503  /*mapping=*/mapping);
3504 }
3505 
3506 void transform::TileUsingForallOp::build(OpBuilder &builder,
3507  OperationState &result, Value target,
3508  ArrayRef<OpFoldResult> mixedTileSizes,
3510  ArrayAttr mapping) {
3511  SmallVector<int64_t> staticTileSizes;
3512  SmallVector<Value> dynamicTileSizes;
3513  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3514  // Call the default builder which sets up the proper operands segment sizes
3515  // attributes for multiple variadic operands. In the absence of this,
3516  // horrible bugs ensue.
3517  MLIRContext *ctx = builder.getContext();
3518  auto operationType = transform::AnyOpType::get(ctx);
3519  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3520  build(builder, result,
3521  /*resultTypes=*/TypeRange{operationType, operationType},
3522  /*target=*/target,
3523  /*num_threads=*/ValueRange{},
3524  /*tile_sizes=*/dynamicTileSizes,
3525  /*packed_num_threads=*/Value(),
3526  /*packed_tile_sizes=*/Value(),
3527  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3528  /*static_tile_sizes=*/staticTileSizesAttr,
3529  /*mapping=*/mapping);
3530 }
3531 
3532 void transform::TileUsingForallOp::build(OpBuilder &builder,
3533  OperationState &result, Value target,
3534  ArrayRef<int64_t> staticNumThreads,
3536  ArrayAttr mapping) {
3537  return build(builder, result, target,
3538  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3539  NumThreadsSpec(), mapping);
3540 }
3541 
3542 void transform::TileUsingForallOp::build(OpBuilder &builder,
3543  OperationState &result, Value target,
3544  ArrayRef<OpFoldResult> mixedNumThreads,
3546  ArrayAttr mapping) {
3547  SmallVector<int64_t> staticNumThreads;
3548  SmallVector<Value> dynamicNumThreads;
3549  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3550  staticNumThreads);
3551  // Call the default builder which sets up the proper operands segment sizes
3552  // attributes for multiple variadic operands. In the absence of this,
3553  // horrible bugs ensue.
3554  MLIRContext *ctx = builder.getContext();
3555  auto operationType = transform::AnyOpType::get(ctx);
3556  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3557  build(builder, result,
3558  /*resultTypes=*/TypeRange{operationType, operationType},
3559  /*target=*/target,
3560  /*num_threads=*/dynamicNumThreads,
3561  /*tile_sizes=*/ValueRange{},
3562  /*packed_num_threads=*/Value(),
3563  /*packed_tile_sizes=*/Value(),
3564  /*static_num_threads=*/staticNumThreadsAttr,
3565  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3566  /*mapping=*/mapping);
3567 }
3568 
3569 /// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3570 /// normalized upper bound.
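/// For example, lb = 2, ub = 10 and step = 3 give the normalized upper bound
/// ceil((10 - 2) / 3) = 3, i.e. iterations 0, 1 and 2.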
3574  ArrayRef<OpFoldResult> steps) {
3575  AffineExpr s0, s1, s2;
3576  bindSymbols(rewriter.getContext(), s0, s1, s2);
3577  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3578  SmallVector<OpFoldResult> normalizedUbs;
3579  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3581  rewriter, loc, normalizedUbExpr, {lb, ub, step});
3582  normalizedUbs.push_back(normalizedUb);
3583  }
3584  return normalizedUbs;
3585 }
3586 
3587 /// When a loop is normalized, the uses of the induction variable within the
3588 /// loop need to be replaced with `original_lb + old_iv * original_step`.
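/// Continuing the example above (lb = 2, step = 3), normalized iteration j
/// maps back to the original induction variable value 2 + j * 3, i.e. 2, 5, 8.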
3590  Location loc, ValueRange ivs,
3592  ArrayRef<OpFoldResult> steps) {
3593  AffineExpr s0, s1;
3594  AffineExpr d0;
3595  bindSymbols(rewriter.getContext(), s0, s1);
3596  bindDims(rewriter.getContext(), d0);
3597  AffineExpr denormExpr = s0 + d0 * s1;
3598  SmallVector<Value> denormalizedIvs;
3599 
3600  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3602  rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3603  denormalizedIvs.push_back(
3604  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3605  }
3606  return denormalizedIvs;
3607 }
3608 
3609 /// Given a `scf.forall` loop return a loop op with the loop bounds
3610 /// normalized.
3611 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3612 /// At the time of writing, this wasn't done since adding this to the `scf`
3613 /// dialect would disallow the use of `affine.apply` operations due
3614 /// to cyclic dependencies. To avoid churn in lit tests
3615 /// with the change this was added with, defer that to a follow-up.
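/// Schematic effect of the rewrite, shown for a single dimension:
///   scf.forall (%i) = (2) to (10) step (3) { use(%i) }
/// becomes
///   scf.forall (%j) in (3) { use(2 + %j * 3) }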
3616 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3617  scf::ForallOp loop) {
3618  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3619  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3620  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3621 
3622  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3623  return loop;
3624  }
3625 
3626  Location loc = loop.getLoc();
3627  SmallVector<OpFoldResult> normalizedUbs =
3628  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3629  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3630  rewriter.getIndexAttr(0));
3631  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3632  rewriter.getIndexAttr(1));
3633 
3634  auto normalizedForallOp = scf::ForallOp::create(
3635  rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3636  loop.getOutputs(), loop.getMapping(),
3637  [](OpBuilder &, Location, ValueRange) {});
3638 
3639  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3640  OpBuilder::InsertionGuard g(rewriter);
3641  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3642  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3643 
3644  SmallVector<Value> argValues =
3645  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3646  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3647  normalizedForallOp.getRegionIterArgs().end());
3648  Block *origLoopBlock = loop.getBody();
3649  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3650 
3651  rewriter.replaceOp(loop, normalizedForallOp);
3652  return normalizedForallOp;
3653 }
3654 
3656  RewriterBase &rewriter, transform::TransformState &state,
3657  TransformOpInterface transformOp, Operation *target,
3658  ArrayRef<OpFoldResult> mixedNumThreads,
3659  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3660  scf::SCFTilingResult &tilingResult) {
3661  // Transform all targets one by one.
3662  auto tileableOp = dyn_cast<TilingInterface>(target);
3663  if (!tileableOp) {
3665  transformOp.emitSilenceableError()
3666  << "only TilingInterface ops are supported";
3667  diag.attachNote(target->getLoc()) << "target op";
3668  return diag;
3669  }
3670  rewriter.setInsertionPoint(tileableOp);
3673  if (!mixedNumThreads.empty()) {
3674  options.setNumThreads(mixedNumThreads);
3675  } else {
3676  options.setTileSizes(mixedTileSizes);
3677  }
3678  if (mapping) {
3679  options.setMapping(mapping.value().getValue());
3680  }
3681  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3682  scf::tileUsingSCF(rewriter, tileableOp, options);
3683 
3684  if (failed(maybeTilingResult))
3685  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3686 
3687  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3688 
3689  tilingResult = *maybeTilingResult;
3690 
3691  if (mixedNumThreads.empty()) {
3692  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3693  OpBuilder::InsertionGuard g(rewriter);
3694  rewriter.setInsertionPoint(generatedForallOp);
3695  scf::ForallOp normalizedForallOp =
3696  normalizeForallLoopOp(rewriter, generatedForallOp);
3697  tilingResult.loops.front() = normalizedForallOp;
3698  }
3699 
3700  return DiagnosedSilenceableFailure::success();
3701 }
3702 
3703 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3704  transform::TransformRewriter &rewriter,
3705  transform::TransformResults &transformResults,
3706  transform::TransformState &state) {
3707  auto transformOp = cast<TransformOpInterface>(getOperation());
3708 
3709  // Result payload ops.
3710  SmallVector<Operation *> tileOps;
3711  SmallVector<Operation *> tiledOps;
3712 
3713  // Unpack handles.
3714  SmallVector<OpFoldResult> mixedNumThreads;
3716  getPackedNumThreads()
3718  state, transformOp, mixedNumThreads, getPackedNumThreads())
3720  state, transformOp, mixedNumThreads, getMixedNumThreads());
3721  if (!status.succeeded())
3722  return status;
3723  SmallVector<OpFoldResult> mixedTileSizes;
3724  status = getPackedTileSizes()
3726  state, transformOp, mixedTileSizes, getPackedTileSizes())
3728  state, transformOp, mixedTileSizes, getMixedTileSizes());
3729  if (!status.succeeded())
3730  return status;
3731 
3732  for (Operation *target : state.getPayloadOps(getTarget())) {
3733  scf::SCFTilingResult tilingResult;
3735  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3736  getMapping(), tilingResult);
3737  if (!diag.succeeded())
3738  return diag;
3739  tileOps.push_back(tilingResult.loops.front());
3740  tiledOps.append(tilingResult.tiledOps);
3741  }
3742 
3743  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3744  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3745 
3746  return DiagnosedSilenceableFailure::success();
3747 }
3748 
3749 void transform::TileUsingForallOp::getEffects(
3751  consumesHandle(getTargetMutable(), effects);
3752  onlyReadsHandle(getTileSizesMutable(), effects);
3753  onlyReadsHandle(getNumThreadsMutable(), effects);
3754  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3755  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3756  producesHandle(getOperation()->getOpResults(), effects);
3757  modifiesPayload(effects);
3758 }
3759 
3760 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3761  Builder b(getContext());
3762  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3763 }
3764 
3765 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3766  Builder b(getContext());
3767  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3768 }
3769 
3770 LogicalResult TileUsingForallOp::verify() {
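  // Within each family the packed and unpacked forms are mutually exclusive,
  // e.g. providing both `num_threads` and `packed_num_threads` is rejected,
  // and at least one of the num_threads / tile_sizes families must be given.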
3771  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3772  static_cast<int>(getPackedNumThreads() != Value());
3773  if (numThreadsSpec > 1)
3774  return emitOpError(
3775  "num_threads and packed_num_threads are mutually exclusive");
3776  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3777  static_cast<int>(getPackedTileSizes() != Value());
3778  if (tileSizesSpec > 1)
3779  return emitOpError(
3780  "tile_sizes and packed_tile_sizes are mutually exclusive");
3781  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3782  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3783  "must be specified");
3784  return success();
3785 }
3786 
3787 //===----------------------------------------------------------------------===//
3788 // VectorizeChildrenAndApplyPatternsOp
3789 //===----------------------------------------------------------------------===//
3790 
3791 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3792  OpBuilder &builder, OperationState &result, Value target,
3793  bool foldTypeExtensionsIntoContract, bool vectorizePadding,
3794  bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3795  result.addOperands(target);
3796  if (foldTypeExtensionsIntoContract) {
3797  result.addAttribute(
3798  VectorizeChildrenAndApplyPatternsOp::
3799  getFoldTypeExtensionsIntoContractAttrName(result.name),
3800  builder.getUnitAttr());
3801  }
3802  if (vectorizePadding) {
3803  result.addAttribute(
3804  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3805  result.name),
3806  builder.getUnitAttr());
3807  }
3808  if (vectorizeExtract) {
3809  result.addAttribute(
3810  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3811  result.name),
3812  builder.getUnitAttr());
3813  }
3814  if (flatten1DDepthwiseConv) {
3815  result.addAttribute(
3816  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3817  result.name),
3818  builder.getUnitAttr());
3819  }
3820  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3821 }
3822 
3823 namespace {
3824 /// This is a helper only to call vectorize via a pattern inside of
3825 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3826 struct VectorizationPattern : public RewritePattern {
3827  explicit VectorizationPattern(MLIRContext *context,
3828  bool vectorizeExtract = false,
3829  bool flattenConv = false)
3830  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3831  vectorizeNDExtract(vectorizeExtract),
3832  flatten1DDepthwiseConv(flattenConv) {}
3833  LogicalResult matchAndRewrite(Operation *op,
3834  PatternRewriter &rewriter) const override {
3836  return rewriter.notifyMatchFailure(op,
3837  "Unsupported Op, cannot vectorize");
3838  FailureOr<VectorizationResult> vectorResults =
3839  vectorize(rewriter, op, /*inputVectorSizes=*/{},
3840  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3841  flatten1DDepthwiseConv);
3842  if (failed(vectorResults))
3843  return failure();
3844  rewriter.replaceOp(op, vectorResults->replacements);
3845  return success();
3846  }
3847 
3848 private:
3849  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3850  /// rank >= 2.
3851  bool vectorizeNDExtract = false;
3852  /// Controls whether to "flatten" the channel dimension when vectorizing 1D
3853  /// depthwise convolutions. This should lead to better vectorization for
3854  /// tensors with a low number of channel dimensions.
3855  bool flatten1DDepthwiseConv = false;
3856 };
3857 } // namespace
3858 
3860 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3861  transform::TransformRewriter &rewriter, Operation *target,
3863  transform::TransformState &state) {
3864  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3865  auto diag = this->emitOpError("requires isolated-from-above targets");
3866  diag.attachNote(target->getLoc()) << "non-isolated target";
3868  }
3869 
3870  MLIRContext *ctx = getContext();
3872  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3873  getFlatten_1dDepthwiseConv());
3874 
3875  if (!getDisableTransferPermutationMapLoweringPatterns())
3877 
3878  if (!getDisableMultiReductionToContractPatterns())
3880 
3882 
3885  /*benefit=*/2);
3886  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3887  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3889 
3891 
3892  if (getFoldTypeExtensionsIntoContract())
3894 
3895  if (getVectorizePadding()) {
3897  // This creates an alternative path for lowering tensor.pad - by
3898  // decomposing it into e.g. linalg.fill.
3900  }
3902 
3903  TrackingListener listener(state, *this);
3904  if (failed(
3905  applyPatternsGreedily(target, std::move(patterns),
3906  GreedyRewriteConfig().setListener(&listener))))
3907  return emitDefaultDefiniteFailure(target);
3908 
3909  results.push_back(target);
3910  return DiagnosedSilenceableFailure::success();
3911 }
3912 
3913 //===----------------------------------------------------------------------===//
3914 // VectorizeOp
3915 //===----------------------------------------------------------------------===//
3916 
3917 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3918  transform::TransformRewriter &rewriter,
3919  mlir::transform::TransformResults &transformResults,
3921  auto targets = state.getPayloadOps(getTarget());
3922  if (std::empty(targets))
3923    return DiagnosedSilenceableFailure::success();
3924  auto transformOp = cast<TransformOpInterface>(getOperation());
3925  SmallVector<int64_t> vectorSizes;
3927  state, transformOp, getMixedVectorSizes(), vectorSizes);
3928  if (!status.succeeded())
3929  return status;
3930 
3931  // TODO: Check that the correct number of vectorSizes was provided.
3932  for (Operation *target : targets) {
3933  if (!linalg::hasVectorizationImpl(target)) {
3934  return mlir::emitSilenceableFailure(target->getLoc())
3935  << "Unsupported Op, cannot vectorize";
3936  }
3937  FailureOr<VectorizationResult> vectorResults =
3938  linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
3939  getVectorizeNdExtract().value_or(false),
3940  /*flatten1DDepthwiseConv=*/false,
3941  getAssumeDynamicDimsMatchVecSizes().value_or(false),
3942  getCreateNamedContraction().value_or(false));
3943  if (failed(vectorResults)) {
3944  return mlir::emitSilenceableFailure(target->getLoc())
3945  << "Attempted to vectorize, but failed";
3946  }
3947  rewriter.replaceOp(target, vectorResults->replacements);
3948  }
3949 
3950  return DiagnosedSilenceableFailure::success();
3951 }
3952 
3953 void transform::VectorizeOp::getEffects(
3955  consumesHandle(getTargetMutable(), effects);
3956  onlyReadsHandle(getVectorSizesMutable(), effects);
3957  modifiesPayload(effects);
3958 }
3959 
3960 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3961  OpBuilder b(getContext());
3962  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3963 }
3964 
3965 LogicalResult transform::VectorizeOp::verify() {
3966  if (getStaticVectorSizes().size() != getScalableSizes().size())
3967  return emitOpError("expected same number of vector sizes (")
3968  << getStaticVectorSizes().size() << ") and scalable sizes ("
3969  << getScalableSizes().size() << ")";
3970  return success();
3971 }
3972 
3973 //===----------------------------------------------------------------------===//
3974 // HoistRedundantVectorTransfersOp
3975 //===----------------------------------------------------------------------===//
3976 
3978 transform::HoistRedundantVectorTransfersOp::applyToOne(
3979  transform::TransformRewriter &rewriter, func::FuncOp target,
3981  transform::TransformState &state) {
3982  // WARNING: This hoisting does not model parallelism and is generally
3983  // incorrect when used on distributed loops with memref semantics!
3984  // TODO: obsolete and should be retired.
3985  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3986  results.push_back(target);
3987  return DiagnosedSilenceableFailure::success();
3988 }
3989 
3990 //===----------------------------------------------------------------------===//
3991 // HoistRedundantVectorBroadcastsOp
3992 //===----------------------------------------------------------------------===//
3993 
3995 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3996  transform::TransformRewriter &rewriter, mlir::Operation *target,
3998  transform::TransformState &state) {
3999  rewriter.setInsertionPoint(target);
4000  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
4001  results.push_back(target);
4002  return DiagnosedSilenceableFailure::success();
4003 }
4004 
4005 //===----------------------------------------------------------------------===//
4006 // ConvertConv2DToImg2ColOp.
4007 //===----------------------------------------------------------------------===//
4008 
4009 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4010  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4012  transform::TransformState &state) {
4013  rewriter.setInsertionPoint(target);
4014  auto maybeTransformed =
4016  target)
4017  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4018  return rewriteInIm2Col(rewriter, op);
4019  })
4020  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4021  return rewriteInIm2Col(rewriter, op);
4022  })
4023  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4024  return rewriteInIm2Col(rewriter, op);
4025  })
4026  .Case([&](linalg::Conv2DNchwFchwOp op) {
4027  return rewriteInIm2Col(rewriter, op);
4028  })
4029  .Default([&](Operation *op) {
4030  return rewriter.notifyMatchFailure(op, "not supported");
4031  });
4032  if (failed(maybeTransformed))
4033  return emitDefaultSilenceableFailure(target);
4034  // Handle to the operation producing the img2col tensor.
4035  results.push_back(maybeTransformed->first);
4036  // Handle to the operation that replaces the original convolution.
4037  results.push_back(maybeTransformed->second);
4038  return DiagnosedSilenceableFailure::success();
4039 }
4040 
4041 //===----------------------------------------------------------------------===//
4042 // FlattenElementwiseLinalgOp.
4043 //===----------------------------------------------------------------------===//
4044 
4045 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4046  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4048  transform::TransformState &state) {
4049  rewriter.setInsertionPoint(target);
4050  if (!isElementwise(target))
4051  return mlir::emitSilenceableFailure(target->getLoc())
4052  << "only elementwise flattening is supported";
4053 
4054  // If rank <= 1, do nothing
4055  if (target.getNumLoops() <= 1) {
4056  results.push_back(target);
4057    return DiagnosedSilenceableFailure::success();
4058  }
4059 
4060  // Attempt to flatten all dims to one.
4061  ReassociationIndices reassociation(target.getNumLoops());
4062  std::iota(reassociation.begin(), reassociation.end(), 0);
4063  auto maybeFlattened =
4064  collapseOpIterationDims(target, reassociation, rewriter);
4065  if (failed(maybeFlattened))
4066  return mlir::emitSilenceableFailure(target->getLoc())
4067  << "attempted to flatten, but failed";
4068  results.push_back(maybeFlattened->collapsedOp);
4069  rewriter.replaceOp(target, maybeFlattened->results);
4070  return DiagnosedSilenceableFailure::success();
4071 }
4072 
4073 //===----------------------------------------------------------------------===//
4074 // TransposeConv2DOp
4075 //===----------------------------------------------------------------------===//
4076 
4077 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4078  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4080  transform::TransformState &state) {
4081  rewriter.setInsertionPoint(target);
4082  auto maybeTransformed =
4084  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4085  return transposeConv2D(rewriter, op);
4086  })
4087  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4088  return transposeConv2D(rewriter, op);
4089  })
4090  .Default([&](Operation *op) {
4091  return rewriter.notifyMatchFailure(op, "not supported");
4092  });
4093  if (failed(maybeTransformed))
4094  return emitDefaultSilenceableFailure(target);
4095  // Handle to the new Conv2D operation with transposed filters
4096  results.push_back(*maybeTransformed);
4097  return DiagnosedSilenceableFailure::success();
4098 }
4099 
4100 //===----------------------------------------------------------------------===//
4101 // TransposeMatmulOp
4102 //===----------------------------------------------------------------------===//
4103 
4104 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4105  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4107  transform::TransformState &state) {
4108  rewriter.setInsertionPoint(target);
4109  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4110  auto maybeTransformed =
4112  .Case([&](linalg::MatmulOp op) {
4113  return transposeMatmul(rewriter, op, transposeLHS);
4114  })
4115  .Case([&](linalg::BatchMatmulOp op) {
4116  return transposeBatchMatmul(rewriter, op, transposeLHS);
4117  })
4118  .Default([&](Operation *op) { return failure(); });
4119  if (failed(maybeTransformed))
4120  return emitSilenceableFailure(target->getLoc()) << "not supported";
4121  // Handle to the new Matmul operation with transposed filters
4122  results.push_back(*maybeTransformed);
4123  return DiagnosedSilenceableFailure::success();
4124 }
4125 
4126 //===----------------------------------------------------------------------===//
4127 // InsertSliceToCopyOp
4128 //===----------------------------------------------------------------------===//
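/// Rewrites an insert_slice whose source is not already produced by a
/// linalg.copy into an explicit copy. Schematically (offsets, sizes, strides
/// and types elided):
///   %r = tensor.insert_slice %src into %dest[...]
/// becomes
///   %e = tensor.extract_slice %dest[...]
///   %c = linalg.copy ins(%src : ...) outs(%e : ...)
///   %r = tensor.insert_slice %c into %dest[...]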
4129 template <typename OpTy>
4132  transform::TransformState &state) {
4133  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4134  tensor::ParallelInsertSliceOp>() &&
4135  "wrong op type");
4136 
4137  if (auto copySource =
4138  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4139  results.push_back(copySource);
4140    return DiagnosedSilenceableFailure::success();
4141  }
4142 
4143  // If we are inside an InParallel region, temporarily set the insertion point
4144  // outside: only tensor.parallel_insert_slice ops are allowed in there.
4145  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
4146  rewriter.setInsertionPoint(
4147  target->template getParentOfType<scf::InParallelOp>());
4148  }
4149 
4150  Value extracted = tensor::ExtractSliceOp::create(
4151  rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4152  target.getMixedSizes(), target.getMixedStrides());
4153  Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4154  target.getSource(), extracted)
4155  .getResult(0);
4156  // Reset the insertion point.
4157  rewriter.setInsertionPoint(target);
4158  rewriter.replaceOpWithNewOp<OpTy>(
4159  target, copied, target.getDest(), target.getMixedOffsets(),
4160  target.getMixedSizes(), target.getMixedStrides());
4161 
4162  results.push_back(copied.getDefiningOp());
4163  return DiagnosedSilenceableFailure::success();
4164 }
4165 
4166 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4167  transform::TransformRewriter &rewriter, Operation *targetOp,
4169  transform::TransformState &state) {
4170 
4171  rewriter.setInsertionPoint(targetOp);
4172  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4173  return doit(rewriter, target, results, state);
4174  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4175  return doit(rewriter, target, results, state);
4176 
4178  emitSilenceableError()
4179  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4180  diag.attachNote(targetOp->getLoc()) << "target op";
4181  return diag;
4182 }
4183 
4184 //===----------------------------------------------------------------------===//
4185 // MapCopyToThreadsOp
4186 //===----------------------------------------------------------------------===//
4187 
4188 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4189  transform::TransformRewriter &rewriter, Operation *target,
4191  transform::TransformState &state) {
4192  // Check if the op is supported.
4193  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4195  emitSilenceableError()
4196  << "only linalg.copy and tensor.pad target ops are supported";
4197  diag.attachNote(target->getLoc()) << "target op";
4198  return diag;
4199  }
4200  assert(target->getNumResults() == 1 && "expected single result");
4201  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4202  if (!resultShapedType.hasStaticShape()) {
4204  emitSilenceableError()
4205  << "only statically sized ops of rank <= 3 are supported";
4206  diag.attachNote(target->getLoc()) << "target op";
4207  return diag;
4208  }
4209 
4210  // Conservatively set the minimum viable desired bitwidth alignment.
4211  int64_t desiredBitAlignment = getDesiredBitAlignment();
4212  int64_t eltBitwidth =
4213  resultShapedType.getElementType().getIntOrFloatBitWidth();
4214  if (desiredBitAlignment % eltBitwidth != 0) {
4215  desiredBitAlignment = eltBitwidth;
4216  }
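  // e.g. a requested 128-bit alignment is kept for 32-bit elements
  // (128 % 32 == 0), while a requested 32-bit alignment with 64-bit elements
  // falls back to 64.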
4217 
4218  gpu::CopyMappingInfo mapping(
4219  /*ctx=*/getContext(),
4220  /*totalNumThreads=*/getTotalNumThreads(),
4221  /*alignment=*/desiredBitAlignment,
4222  /*sizes=*/resultShapedType.getShape(),
4223  /*favorPredication=*/false,
4224  /*elementalBitwidth=*/
4225  resultShapedType.getElementType().getIntOrFloatBitWidth());
4226  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4228  emitSilenceableError()
4229  << "too few threads to map copy op to threads on the most minor "
4230  "dimension, given alignment and vector size constraints, try "
4231  "smaller tile size of mapping to more threads";
4232  diag.attachNote(target->getLoc()) << "target op";
4233  return diag;
4234  }
4235 
4236  // OpBuilder only used to compute attributes.
4237  OpBuilder b(getContext());
4238  scf::SCFTilingResult tilingResult;
4240  /*rewriter=*/rewriter,
4241  /*state=*/state,
4242  /*transformOp=*/*this,
4243  /*target=*/target,
4244  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4245  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4246  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4247  /*tilingResult=*/tilingResult);
4248  if (!diag.succeeded())
4249  return diag;
4250 
4251  results.push_back(tilingResult.loops.front());
4252  for (auto op : tilingResult.tiledOps)
4253  results.push_back(op);
4254  return DiagnosedSilenceableFailure::success();
4255 }
4256 
4257 //===----------------------------------------------------------------------===//
4258 // WinogradConv2DOp
4259 //===----------------------------------------------------------------------===//
4260 
4261 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4262  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4264  transform::TransformState &state) {
4265  rewriter.setInsertionPoint(target);
4266  FailureOr<Operation *> maybeTransformed = failure();
4267  bool supported = TypeSwitch<Operation *, bool>(target)
4268  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4269  maybeTransformed =
4270  winogradConv2D(rewriter, op, getFmr());
4271  return true;
4272  })
4273  .Default([&](Operation *op) { return false; });
4274 
4275  if (!supported) {
4276  return emitSilenceableError()
4277  << "this operation is not supported to convert to Winograd Conv2D";
4278  }
4279 
4280  if (failed(maybeTransformed)) {
4281  return emitSilenceableError() << "apply Winograd Conv2D failed";
4282  }
4283 
4284  results.push_back(*maybeTransformed);
4285  return DiagnosedSilenceableFailure::success();
4286 }
4287 
4288 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4289  transform::TransformRewriter &rewriter, Operation *target,
4291  transform::TransformState &state) {
4292  rewriter.setInsertionPoint(target);
4293  FailureOr<Operation *> maybeTransformed = failure();
4294  bool supported =
4296  .Case([&](linalg::WinogradFilterTransformOp op) {
4297  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4298  return true;
4299  })
4300  .Case([&](linalg::WinogradInputTransformOp op) {
4301  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4302  return true;
4303  })
4304  .Case([&](linalg::WinogradOutputTransformOp op) {
4305  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4306  return true;
4307  })
4308  .Default([&](Operation *op) { return false; });
4309 
4310  if (!supported) {
4312  emitSilenceableError()
4313  << "this operation is not supported to decompose into other operations";
4314  diag.attachNote(target->getLoc()) << "target op";
4315  return diag;
4316  }
4317 
4318  if (failed(maybeTransformed)) {
4320  emitSilenceableError() << "decompose Winograd operations failed";
4321  diag.attachNote(target->getLoc()) << "target op";
4322  return diag;
4323  }
4324 
4325  results.push_back(*maybeTransformed);
4326  return DiagnosedSilenceableFailure::success();
4327 }
4328 
4329 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4330 
4331 #define GET_OP_CLASSES
4332 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
Definition: Operation.h:350
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1276
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1374
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
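A minimal usage sketch (not from this file) of the folded affine-apply helper; `b`, `loc`, `lhsOfr` and `rhsOfr` are assumed to be an OpBuilder, a Location and two index-typed OpFoldResults:
// Compose and fold (s0 + s1) over two OpFoldResults; no affine.apply op is
// materialized when both operands fold to constant attributes.
AffineExpr s0, s1;
bindSymbols(b.getContext(), s0, s1);
OpFoldResult sum = affine::makeComposedFoldedAffineApply(
    b, loc, AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2, s0 + s1),
    {lhsOfr, rhsOfr});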
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition: Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:471
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:657
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
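A hedged sketch of driving the vectorizer on a single candidate op, written against the signature shown above; `rewriter`, `candidateOp` and `vectorSizes` are assumptions, not names from this file:
// Only attempt ops the Linalg vectorizer has dedicated logic for.
if (linalg::hasVectorizationImpl(candidateOp)) {
  FailureOr<linalg::VectorizationResult> vectorized =
      linalg::vectorize(rewriter, candidateOp,
                        /*inputVectorSizes=*/vectorSizes);
  if (failed(vectorized))
    return failure();
}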
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:247
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:346
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:857
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:512
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:496
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:487
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
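A small sketch, assuming `rewriter` is a RewriterBase and `namedOp` is a named LinalgOp, that generalizes the op and then attempts to re-specialize the resulting generic:
FailureOr<GenericOp> generic = linalg::generalizeNamedOp(rewriter, namedOp);
if (succeeded(generic)) {
  // Specialization may fail if no named-op equivalent is recognized.
  (void)linalg::specializeGenericOp(rewriter, *generic);
}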
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Convert linalg.batch_matmul ops to transposed variants by materializing a transpose of the LHS or RHS operand.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:400
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
Definition: Promotion.cpp:504
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:220
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:45
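A minimal sketch, with `rewriter` and `genericOp` assumed, swapping the first two iteration dimensions:
// The interchange vector must be a permutation of the iteration dimensions.
FailureOr<GenericOp> interchanged = linalg::interchangeGenericOp(
    rewriter, genericOp, /*interchangeVector=*/{1, 0});
if (failed(interchanged))
  return failure();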
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:236
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:106
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:89
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:748
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:464
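A hedged sketch of packing a matmul-like op with constant packed sizes for (m, n, k); `rewriter` and `matmulOp` are assumptions:
SmallVector<OpFoldResult> packedSizes =
    getAsIndexOpFoldResult(rewriter.getContext(), {8, 16, 32});
FailureOr<linalg::PackResult> packed =
    linalg::pack(rewriter, matmulOp, packedSizes);
if (failed(packed))
  return failure();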
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:422
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:491
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In the case of GPU workgroup memory there is no need to deallocate.
Definition: Promotion.cpp:480
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of a linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:198
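For example (a sketch; `funcOp` is an assumed function-like payload op):
// Hoist loop-invariant transfer_read/transfer_write pairs out of enclosing
// scf.for loops, checking that each loop runs at least one iteration.
linalg::hoistRedundantVectorTransfers(funcOp, /*verifyNonZeroTrip=*/true);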
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
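A minimal sketch, assuming `rewriter` and a payload op `tilingOp` implementing TilingInterface, splitting dimension 0 at a static point:
auto [firstPart, secondPart] =
    linalg::splitOp(rewriter, cast<TilingInterface>(tilingOp),
                    /*dimension=*/0, /*splitPoint=*/rewriter.getIndexAttr(64));
// The two resulting ops cover the ranges before and after the split point.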
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:262
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:217
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
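A hedged sketch, assuming `rewriter` and a payload `tilingOp` implementing TilingInterface, of tiling with static tile sizes (see the SCFTilingOptions entries further below):
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(
    getAsIndexOpFoldResult(rewriter.getContext(), {4, 8}));
FailureOr<scf::SCFTilingResult> tiled =
    scf::tileUsingSCF(rewriter, tilingOp, tilingOptions);
if (failed(tiled))
  return failure();
// tiled->loops holds the generated loops; tiled->tiledOps the tiled ops.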
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:617
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:61
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
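For example (a sketch; `ctx` is an assumed MLIRContext*):
OpFoldResult ofr = getAsIndexOpFoldResult(ctx, /*val=*/42);
std::optional<int64_t> cst = getConstantIntValue(ofr); // contains 42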
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
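A minimal sketch, assuming `op` is an isolated-from-above payload op with one region:
RewritePatternSet patterns(op->getContext());
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
// Applies patterns until a fixed point (or the iteration limit) is reached.
if (failed(applyPatternsGreedily(op->getRegion(0), std::move(patterns))))
  return failure();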
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
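A short sketch of the OpFoldResult helpers above; `b`, `loc` and `mixedSizes` (an ArrayRef<OpFoldResult>) are assumptions:
// Split mixed static/dynamic sizes into the two vectors ops typically store.
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticSizes;
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
// Materialize a single OpFoldResult as an SSA value when IR is needed.
Value first = getValueOrCreateConstantIndexOp(b, loc, mixedSizes.front());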
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:422
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:522
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1626
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1558
Match and rewrite for the pattern:
Definition: Transforms.h:1748
Match and rewrite for the pattern:
Definition: Transforms.h:1776
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:421
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:434
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:447
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:467
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:410
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:441
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:457
LinalgPromotionOptions & setUseOriginalSubviewSize(bool originalSize)
Definition: Transforms.h:428
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:399
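A hedged sketch combining these setters with the promotion entry points listed earlier; `b` and `linalgOp` are assumptions:
linalg::LinalgPromotionOptions promotionOptions;
promotionOptions.setOperandsToPromote({0, 1})
    .setAlignment(16)
    .setUseFullTileBuffersByDefault(true);
if (succeeded(linalg::promoteSubviewsPrecondition(linalgOp, promotionOptions)))
  // Allocates promoted buffers at the builder's insertion point.
  (void)linalg::promoteSubViews(b, linalgOp, promotionOptions);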
Split Reduction options.
Definition: Transforms.h:476
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.