MLIR 23.0.0git
LinalgTransformOps.cpp
Go to the documentation of this file.
1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
40#include "mlir/Support/LLVM.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/ScopeExit.h"
44#include "llvm/ADT/SmallPtrSet.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
49#include <type_traits>
50
51using namespace mlir;
52using namespace mlir::linalg;
53using namespace mlir::transform;
54
55#define DEBUG_TYPE "linalg-transforms"
56
57/// Attempts to apply the pattern specified as template argument to the given
58/// operation. The pattern is expected to have a `returningMatchAndRewrite`
59/// function that returns the "main" result or failure. Returns failure if the
60/// pattern failed to apply. Extra arguments are forwarded to the pattern
61/// constructor.
62template <typename PatternTy, typename... Args>
63static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
64 // Check if the given operation has the type expected by the pattern.
65 using OpTy = typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
68 if (!op)
69 return failure();
70
71 // Apply the pattern directly to the op.
72 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
73 // We want to discourage direct use of PatternRewriter in APIs but In this
74 // very specific case, an IRRewriter is not enough.
75 PatternRewriter rewriter(operation->getContext());
76 rewriter.setInsertionPoint(operation);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
78 if (failed(result))
79 return failure();
80 return cast<LinalgOp>(result->getOperation());
81}
82
83/// Assuming that `ofr` is an index attr or a param of index type
84/// or a transform dialect handle mapped to exactly one op
85/// with one index result, return that value.
87 transform::TransformState &state, TransformOpInterface transformOp,
89 for (OpFoldResult ofr : ofrs) {
90 if (auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
93 result.push_back(ofr);
94 continue;
95 }
96
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
99 ArrayRef<Attribute> params = state.getParams(transformValue);
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 << "requires exactly one parameter associated";
103 result.push_back(params[0]);
104 continue;
105 }
106
107 auto payloadOps = state.getPayloadOps(transformValue);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 << "handle must be mapped to exactly one payload op";
112 diag.attachNote(transformValue.getLoc())
113 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
114 return diag;
115 }
116
117 Operation *op = *payloadOps.begin();
118 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
120 transformOp.emitSilenceableError()
121 << "payload op must have exactly 1 index result";
122 diag.attachNote(op->getLoc())
123 << "has " << op->getNumResults() << " results";
124 return diag;
125 }
126 result.push_back(op->getResult(0));
127 }
128
130}
131
132// Given a list of params that are index attrs or a list of OpFoldResults
133// that are either index attrs or op handles, return a list of OpFoldResults
134// of index attrs or a list of OpFoldResults where all op handles are
135// replaced with the first (and only) OpResult of that payload op.
136// (There must be exactly one parameter associated with the AnyParamType or
137// one mapped payload op which must have exactly one index result.)
139 transform::TransformState &state, TransformOpInterface transformOp,
140 SmallVector<OpFoldResult> &result, Value packedHandle) {
141 if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
142 ArrayRef<Attribute> params = state.getParams(packedHandle);
143 for (auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 << "expected the parameter to be associated with an integer "
147 "attribute";
148 result.push_back(param);
149 }
151 }
152
153 for (Operation *op : state.getPayloadOps(packedHandle)) {
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 << "payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 << "has " << op->getNumResults() << " results";
160 return diag;
161 }
162 result.push_back(op->getResult(0));
163 }
164
166}
167
168/// When possible, converts each `OpFoldResult` in `mixedResult` to
169/// an integer if the value can be statically inferred. If a result
170/// is a `Value` then it must be either a `ParamType` or a handle
171/// to an a constant like op.
173 TransformState &state, TransformOpInterface &transformOp,
174 ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
175 for (OpFoldResult paramOrHandle : mixedResults) {
176 if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
178 continue;
179 }
180 if (isa<TransformParamTypeInterface>(
181 cast<Value>(paramOrHandle).getType())) {
182 ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
183 if (params.size() != 1)
184 return transformOp.emitSilenceableError() << "expected a single param";
185 reified.push_back(
186 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
187 continue;
188 }
189
190 Value handle = cast<Value>(paramOrHandle);
191 if (!isa<TransformHandleTypeInterface>(handle.getType()))
192 return transformOp.emitSilenceableError() << "unexpected value handle";
193 auto payload = state.getPayloadOps(handle);
194 if (!llvm::hasSingleElement(payload))
195 return transformOp.emitSilenceableError()
196 << "requires param or handle that is mapped to 1 payload op";
197
198 Operation *paramOrHandlePayloadOp = *payload.begin();
199 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
200 !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
201 return transformOp.emitSilenceableError()
202 << "requires param or handle to be result of op with 1 index "
203 "result";
204 }
205
206 IntegerAttr attr;
207 if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
208 return transformOp.emitSilenceableError()
209 << "requires param or handle to be the result of a constant like "
210 "op";
211
212 reified.push_back(attr.getInt());
213 }
215}
216
217//===----------------------------------------------------------------------===//
218// Apply...PatternsOp
219//===----------------------------------------------------------------------===//
220
221void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
222 RewritePatternSet &patterns) {
224}
225
226void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
227 RewritePatternSet &patterns) {
229}
230
231void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
232 RewritePatternSet &patterns) {
234}
235
236void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
237 RewritePatternSet &patterns) {
240}
241
242void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
243 RewritePatternSet &patterns) {
245 options.rankReductionStrategy =
248}
249
250void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
251 RewritePatternSet &patterns) {
253}
254
255void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
256 RewritePatternSet &patterns) {
258}
259
260void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
261 RewritePatternSet &patterns) {
263}
264
265void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
266 RewritePatternSet &patterns) {
268}
269
270void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
271 RewritePatternSet &patterns) {
273}
274
275void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
276 RewritePatternSet &patterns) {
277 linalg::ControlPropagationFn defaultControlFn = [](OpOperand *operand) {
278 return true;
279 };
280 linalg::populateDataLayoutPropagationPatterns(patterns, defaultControlFn,
281 getPoisonPadding());
282}
283
284void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
285 RewritePatternSet &patterns) {
286 linalg::ControlPropagationFn defaultControlFn =
287 [](OpOperand *opOperand) -> bool {
288 Operation *producer = opOperand->get().getDefiningOp();
289 Operation *consumer = opOperand->getOwner();
290 return consumer->getBlock() == producer->getBlock();
291 };
292 linalg::populateExtractSliceSinkingPatterns(patterns, defaultControlFn);
293}
294
295//===----------------------------------------------------------------------===//
296// BufferizeToAllocationOp
297//===----------------------------------------------------------------------===//
298
299namespace {
300class NewOpsListener : public RewriterBase::ForwardingListener {
301public:
303
304 SmallVector<Operation *> getNewOps() const {
305 return SmallVector<Operation *>(newOps.begin(), newOps.end());
306 }
307
308private:
309 void notifyOperationInserted(Operation *op,
310 OpBuilder::InsertPoint previous) override {
311 ForwardingListener::notifyOperationInserted(op, previous);
312 // We only care about newly created ops.
313 if (previous.isSet())
314 return;
315 auto inserted = newOps.insert(op);
316 (void)inserted;
317 assert(inserted.second && "expected newly created op");
318 }
319
320 void notifyOperationErased(Operation *op) override {
321 ForwardingListener::notifyOperationErased(op);
322 op->walk([&](Operation *op) { newOps.erase(op); });
323 }
324
326};
327} // namespace
328
329DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
332 // Attach listener to keep track of newly created ops.
333 OpBuilder::Listener *previousListener = rewriter.getListener();
334 llvm::scope_exit resetListener(
335 [&]() { rewriter.setListener(previousListener); });
336 NewOpsListener newOpsListener(previousListener);
337 rewriter.setListener(&newOpsListener);
338
340 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
341 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
342 MaterializeInDestination;
343 } else if (getMemcpyOp() == "memref.copy") {
344 options.memcpyOp =
346 } else if (getMemcpyOp() == "linalg.copy") {
347 options.memcpyOp =
349 } else {
350 llvm_unreachable("invalid memcpy op");
351 }
352 if (getAllocOp() == "memref.alloc") {
353 options.allocOp =
355 } else if (getAllocOp() == "memref.alloca") {
356 options.allocOp =
358 } else {
359 llvm_unreachable("invalid alloc op");
360 }
361 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
362 options.emitDealloc = getEmitDealloc();
363
364 // Bufferize ops.
365 Attribute memorySpace =
366 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
367 SmallVector<Value> allocatedBuffers;
368 for (Operation *op : state.getPayloadOps(getTarget())) {
369 Value buffer =
370 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
371 if (!buffer) {
372 DiagnosedSilenceableFailure diag = emitSilenceableError()
373 << "failed to bufferize operation";
374 diag.attachNote(op->getLoc()) << "target payload op";
375 return diag;
376 }
377 allocatedBuffers.push_back(buffer);
378 }
379
380 // Set results.
381 results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
382 results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
384}
385
386void transform::BufferizeToAllocationOp::getEffects(
388 if (getBufferizeDestinationOnly()) {
389 // The destination is replaced with a newly allocated buffer, but the op
390 // itself remains in place.
391 onlyReadsHandle(getTargetMutable(), effects);
392 } else {
393 consumesHandle(getTargetMutable(), effects);
394 }
395 producesHandle(getOperation()->getOpResults(), effects);
396 modifiesPayload(effects);
397}
398
399LogicalResult transform::BufferizeToAllocationOp::verify() {
400 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
401 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
402 return emitOpError() << "unsupported memcpy op";
403 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
404 return emitOpError() << "unsupported alloc op";
405 return success();
406}
407
408//===----------------------------------------------------------------------===//
409// PromoteTensorOp
410//===----------------------------------------------------------------------===//
411
412/// Return true if the operand may be read from by its owner. This is currently
413/// very conservative and only looks inside linalg operations to prevent
414/// unintentional data loss.
415static bool mayBeRead(OpOperand &operand) {
416 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
417
418 // Be conservative about ops we cannot analyze deeper.
419 if (!linalgOp)
420 return true;
421
422 // Look inside linalg ops.
423 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
424 return !blockArgument.use_empty();
425}
426
427/// Return true if the value may be read through any of its uses.
428static bool mayBeRead(Value value) {
429 // If the value has a reference semantics, it
430 // may be read through any alias...
431 if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
432 return true;
433 return llvm::any_of(value.getUses(),
434 static_cast<bool (&)(OpOperand &)>(mayBeRead));
435}
436
438transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
441 SmallVector<Value> promoted;
442 for (Value tensor : state.getPayloadValues(getTensor())) {
443 auto type = dyn_cast<RankedTensorType>(tensor.getType());
444 if (!type) {
445 return emitSilenceableError() << "non-tensor type: " << tensor;
446 }
447
448 Operation *definingOp = tensor.getDefiningOp();
449 if (definingOp)
450 rewriter.setInsertionPointAfter(definingOp);
451 else
452 rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
453
454 // Check this before we emit operations using this value.
455 bool needsMaterialization = mayBeRead(tensor);
456
457 SmallVector<Value> dynamicDims;
459 for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
460 if (!ShapedType::isDynamic(dim))
461 continue;
462 Value cst =
463 arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
464 auto dimOp =
465 tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
466 preservedOps.insert(dimOp);
467 dynamicDims.push_back(dimOp);
468 }
469 auto allocation = bufferization::AllocTensorOp::create(
470 rewriter, tensor.getLoc(), type, dynamicDims);
471 // Set memory space if provided.
472 if (getMemorySpaceAttr())
473 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
474 Value allocated = allocation;
475
476 // Only insert a materialization (typically bufferizes to a copy) when the
477 // value may be read from.
478 if (needsMaterialization) {
479 auto copy = bufferization::MaterializeInDestinationOp::create(
480 rewriter, tensor.getLoc(), tensor, allocated);
481 preservedOps.insert(copy);
482 promoted.push_back(copy.getResult());
483 } else {
484 promoted.push_back(allocated);
485 }
486 rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
487 }
488 results.setValues(cast<OpResult>(getPromoted()), promoted);
490}
491
492void transform::PromoteTensorOp::getEffects(
494 transform::onlyReadsHandle(getTensorMutable(), effects);
495 transform::producesHandle(getOperation()->getOpResults(), effects);
497}
498
499//===----------------------------------------------------------------------===//
500// DecomposeOp
501//===----------------------------------------------------------------------===//
502
transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
 LinalgOp target,
 // NOTE(review): several lines were lost in extraction here — the return
 // type, the remaining signature parameters, the expression producing `res`,
 // and the success return inside the if. Restore from upstream before
 // compiling; the decomposition entry point cannot be inferred from this
 // view.
 FailureOr<linalg::LinalgOp> res =
 if (succeeded(res)) {
 results.push_back(*res);
 }
 // Decomposition did not apply: report the default silenceable failure.
 return emitDefaultSilenceableFailure(target);
}
516
517//===----------------------------------------------------------------------===//
518// DecomposeInterfaceOp
519//===----------------------------------------------------------------------===//
520
521// Decompose the target operation if it implements the AggregatedOpInterface.
522// Push the decomposed operations (the ones that replaces the values produced by
523// \p target) in the `results`.
524DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
528 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
529 if (!decomposableOp) {
531 "payload is not a decomposable op"));
532 return emitDefaultSilenceableFailure(target);
533 }
534
535 FailureOr<SmallVector<Value>> maybeNewResults =
536 decomposableOp.decomposeOperation(rewriter);
537 if (failed(maybeNewResults))
538 return emitDefaultSilenceableFailure(target);
539
540 rewriter.replaceOp(decomposableOp, *maybeNewResults);
541 for (Value val : *maybeNewResults) {
542 Operation *definition = val.getDefiningOp();
543 if (definition)
544 results.push_back(definition);
545 }
547}
548
549//===----------------------------------------------------------------------===//
550// EliminateLinalgOpAnchoredEmptyTensorsOp
551//===----------------------------------------------------------------------===//
552
553void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
555 onlyReadsHandle(getTargetMutable(), effects);
556 modifiesPayload(effects);
557}
558
560transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
561 transform::TransformRewriter &rewriter, TransformResults &transformResults,
562 TransformState &state) {
564 options.allowReturnAllocsFromLoops = true;
565
566 for (Operation *target : state.getPayloadOps(getTarget())) {
568 if (failed(analyzeOp(target, state)))
569 return mlir::emitSilenceableFailure(target->getLoc())
570 << "failed to analyze op";
572 rewriter, target, state)))
573 return mlir::emitSilenceableFailure(target->getLoc())
574 << "failed to eliminate LinalgOp anchored tensor.empty ops";
575 }
577}
578
579//===----------------------------------------------------------------------===//
580// FuseOp
581//===----------------------------------------------------------------------===//
582
583void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
584 TypeRange loopTypes, Value target,
585 ArrayRef<int64_t> staticTileSizes,
586 ArrayRef<int64_t> staticTileInterchange,
587 bool applyCleanup, bool useForall) {
588 return build(
589 builder, result, loopTypes,
590 /*target=*/target,
591 /*mixedTileSizes=*/
592 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
593 /*mixedTileInterchange=*/
594 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
595 applyCleanup, useForall);
596}
597
598void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
599 Value target, ArrayRef<int64_t> staticTileSizes,
600 ArrayRef<int64_t> staticTileInterchange,
601 bool applyCleanup, bool useForall) {
602 return build(
603 builder, result,
604 /*target=*/target,
605 /*mixedTileSizes=*/
606 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
607 /*mixedTileInterchange=*/
608 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
609 applyCleanup, useForall);
610}
611
612void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
614 ArrayRef<OpFoldResult> mixedTileSizes,
615 ArrayRef<OpFoldResult> mixedTileInterchange,
616 bool applyCleanup, bool useForall) {
617 // Loop types are automaticaly splat by the callee, setting up one is
618 // enough.
619 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
620 build(builder, result, loopTypes, target, mixedTileSizes,
621 mixedTileInterchange, applyCleanup, useForall);
622}
623
624void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
625 TypeRange loopTypes, Value target,
626 ArrayRef<OpFoldResult> mixedTileSizes,
627 ArrayRef<OpFoldResult> mixedTileInterchange,
628 bool applyCleanup, bool useForall) {
629 SmallVector<int64_t> staticTileSizes;
630 SmallVector<Value> dynamicTileSizes;
631 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
632 SmallVector<int64_t> staticTileInterchange;
633 SmallVector<Value> dynamicTileInterchange;
634 dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
635 staticTileInterchange);
636 // Call the default builder which sets up the proper operands segment sizes
637 // attributes for multiple variadic operands. In the absence of this,
638 // horrible bugs ensue.
639 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
640 auto staticTileInterchangeAttr =
641 builder.getDenseI64ArrayAttr(staticTileInterchange);
642 unsigned numExpectedLoops =
643 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
644 SmallVector<Type> resultTypes;
645 resultTypes.reserve(numExpectedLoops);
646 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
647 "expected one loop type or as many as loops");
648 if (loopTypes.size() == 1)
649 resultTypes.append(numExpectedLoops, loopTypes[0]);
650 else
651 llvm::append_range(resultTypes, loopTypes);
652 build(builder, result, /*transformed=*/target.getType(),
653 /*loops=*/resultTypes,
654 /*target=*/target,
655 /*tile_sizes=*/dynamicTileSizes,
656 /*tile_interchange=*/dynamicTileInterchange,
657 /*static_tile_sizes=*/staticTileSizesAttr,
658 /*static_tile_interchange=*/staticTileInterchangeAttr,
659 /*apply_cleanup=*/applyCleanup,
660 /*use_forall=*/useForall);
661}
662
663/// Apply a tiling transformation to all payload ops and store both the
664/// tiled operation as well as the created tile loops.
665template <typename Range>
666static LogicalResult applyTilingToAll(
667 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
668 unsigned numLoops, transform::TransformResults &transformResults,
669 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
670 applyFn) {
671 SmallVector<Operation *> tiledLinalgOps;
672 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
673
674 for (Operation *target : payloadOps) {
675 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
676 if (!tilingInterfaceOp)
677 return transformOp->emitError("only TilingInterface ops are supported");
678
679 rewriter.setInsertionPoint(target);
680 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
681 applyFn(tilingInterfaceOp);
682 if (failed(tiledResults))
683 return failure();
684
685 // Perform the replacement of tiled and fused values.
686 SmallVector<Operation *> opsToReplace{target};
687 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
688 for (Operation *toReplace : opsToReplace) {
689 for (OpResult res : toReplace->getResults())
690 if (auto replacement = tiledResults->replacements.lookup(res))
691 rewriter.replaceAllUsesWith(res, replacement);
692 if (toReplace->use_empty()) {
693 rewriter.eraseOp(toReplace);
694 }
695 }
696
697 // Report back the relevant handles to the transform op.
698 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
699 assert(tiledResults->loops.size() == numLoops &&
700 "Mismatched number of loops, tile and fuse transform should have "
701 "failed");
702 for (unsigned int i = 0; i < numLoops; ++i)
703 loopOps[i].push_back(tiledResults->loops[i]);
704 }
705
706 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
707 for (unsigned int i = 0; i < numLoops; ++i)
708 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
709
710 return success();
711}
712
714transform::FuseOp::apply(transform::TransformRewriter &rewriter,
715 mlir::transform::TransformResults &transformResults,
717 auto transformOp = cast<TransformOpInterface>(getOperation());
718
719 SmallVector<int64_t> tileSizes;
721 state, transformOp, getMixedTileSizes(), tileSizes);
722 if (!status.succeeded())
723 return status;
724 SmallVector<int64_t> tileInterchange;
726 state, transformOp, getMixedTileInterchange(), tileInterchange);
727 if (!status.succeeded())
728 return status;
729
730 scf::SCFTilingOptions tilingOptions;
731 tilingOptions.interchangeVector = tileInterchange;
732 bool useForall = getUseForall();
733 tilingOptions.setLoopType(useForall
734 ? scf::SCFTilingOptions::LoopType::ForallOp
735 : scf::SCFTilingOptions::LoopType::ForOp);
736 SmallVector<OpFoldResult> tileSizesOfr =
737 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
738 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
739 scf::SCFTileAndFuseOptions tileAndFuseOptions;
740 tileAndFuseOptions.tilingOptions = tilingOptions;
741
742 if (getApplyCleanup()) {
743 MLIRContext *context = rewriter.getContext();
744 RewritePatternSet patterns(context);
745 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
748 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
749 }
750
751 size_t numLoops =
752 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
753 LogicalResult result = applyTilingToAll(
754 rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
755 transformResults,
756 [&](TilingInterface tilingInterfaceOp)
757 -> FailureOr<scf::SCFTileAndFuseResult> {
758 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
759 tileAndFuseOptions);
760 });
763}
764
765LogicalResult transform::FuseOp::verify() {
766 auto iterspace_rank = getStaticTileSizes().size();
767 ArrayRef<int64_t> permutation = getStaticTileInterchange();
768 if (permutation.size() > iterspace_rank)
769 return emitOpError()
770 << "interchange length exceeds iteration space dimensions ("
771 << iterspace_rank << "), found " << getTileInterchange();
772 SmallVector<bool> seen(iterspace_rank, false);
773 for (int64_t v : permutation) {
774 if (!ShapedType::isDynamic(v)) {
775 if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))
776 return emitOpError() << "expects interchange values to be in range [0, "
777 << iterspace_rank << "), found: " << v;
778 if (seen[v])
779 return emitOpError() << "found duplicate interchange value: " << v;
780 seen[v] = true;
781 }
782 }
783
784 ArrayRef<int64_t> sizes = getStaticTileSizes();
785 size_t numExpectedLoops =
786 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
787 if (numExpectedLoops != getNumResults() - 1)
788 return emitOpError() << "expects " << numExpectedLoops << " loop results";
789
790 return success();
791}
792
793SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
794 return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
795}
796
797SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
798 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
799 getContext());
800}
801
802void transform::FuseOp::getEffects(
804 consumesHandle(getTargetMutable(), effects);
805 onlyReadsHandle(getTileSizesMutable(), effects);
806 onlyReadsHandle(getTileInterchangeMutable(), effects);
807 producesHandle(getOperation()->getOpResults(), effects);
808 modifiesPayload(effects);
809}
810
811//===----------------------------------------------------------------------===//
812// FuseIntoContainingOp
813//===----------------------------------------------------------------------===//
814
815void transform::FuseIntoContainingOp::build(OpBuilder &builder,
817 Value producerOp,
818 Value containingOp) {
819 result.addOperands({producerOp, containingOp});
820 auto resultType = transform::AnyOpType::get(builder.getContext());
821 result.addTypes({resultType, resultType});
822}
823
824/// Add new operands to the forall op for users of the producerOp
825/// that are dominated by the containing scf.forall op.
827 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
828 Operation *containingOp, TilingResult &tileAndFuseResult,
829 int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
831
832 // Count number of users not including the containing op
833 SetVector<Operation *> dominatedUsers;
834 DominanceInfo domInfo(containingOp);
835 for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
836 if (!containingOp->isAncestor(user) &&
837 (domInfo.dominates(containingOp, user))) {
838 dominatedUsers.insert(user);
839 }
840 }
841 if (dominatedUsers.empty())
842 return nullptr;
843
844 // Create new scf.forall op
845 auto forallOp = cast<scf::ForallOp>(containingOp);
846 OpBuilder::InsertionGuard g(rewriter);
847 rewriter.setInsertionPoint(forallOp);
848
849 // Get new output
850 Location loc = forallOp.getLoc();
851 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
852 if (!genericOp)
853 return nullptr;
854 SmallVector<Value> outputs = genericOp.getOutputs();
855 SmallVector<Value> newOuts(forallOp.getOutputs());
856 newOuts.push_back(outputs[resultNumber]);
857
858 // Create new scf.forall op
859 auto newforallOp = scf::ForallOp::create(
860 rewriter, loc, forallOp.getMixedLowerBound(),
861 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
862 forallOp.getMapping());
863 rewriter.eraseBlock(newforallOp.getBody());
864 newforallOp.getRegion().takeBody(forallOp.getRegion());
865
866 // Add additional block argument for new value being returned
867 // and replaces all uses of the new output with corresponding bbArg
868 // inside the scf.forall to enable fusion into this new scf.forall.
869 newforallOp.getBody()->addArgument(newOuts.back().getType(),
870 newOuts.back().getLoc());
871 auto bbArgs = newforallOp.getBody()->getArguments();
872 rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
873 [&](OpOperand &use) {
874 Operation *op = use.getOwner();
875 return newforallOp->isProperAncestor(op);
876 });
877
878 // Fix terminator
879 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
880 SmallVector<Operation *> yieldingOps = llvm::map_to_vector<4>(
881 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
882 Operation *firstYieldOp = yieldingOps.front();
883 rewriter.setInsertionPoint(firstYieldOp);
884 Value src = tileAndFuseResult.tiledValues[0];
885 Value dst = newforallOp.getRegionIterArgs().back();
886 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
887 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
888 dst, offsets, sizes, strides);
889
890 for (auto result : llvm::enumerate(forallOp.getResults())) {
891 rewriter.replaceAllUsesWith(result.value(),
892 newforallOp->getResult(result.index()));
893 }
894 rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
895 newforallOp->getResults().back(),
896 [&](OpOperand &use) {
897 Operation *user = use.getOwner();
898 return dominatedUsers.contains(user);
899 });
900 return newforallOp;
901}
902
903/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
904/// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
905/// outer loop. To determine the second condition, this function iterates
906/// using a worklist over the enclosing loops, trying to find 'src' in any of
907/// the parent loop's iter args.
908static bool sameOrEquivalentIterArg(Value src, Value dst) {
909 // Stack like vector containing possible iterArgs candidates. The first one
910 // is dst, and we will transverse the IR from there.
911 SmallVector<Value> destWorklist;
912 destWorklist.push_back(dst);
913
914 while (!destWorklist.empty()) {
915 Value currentDst = destWorklist.pop_back_val();
916
917 // We have found the same operand in some iter arg in the loop structure,
918 // so src and dst are equivalent.
919 if (src == currentDst)
920 return true;
921
922 // The operands are not equivalent, look for enclosing loops over
923 // currentDst.
924 auto bbArg = dyn_cast<BlockArgument>(currentDst);
925 if (!bbArg)
926 continue;
927
928 Block *parentBlock = bbArg.getOwner();
929 assert(parentBlock && "unlinked block argument");
930
931 Operation *parentOp = parentBlock->getParentOp();
932 assert(parentOp && "expected block argument with parent operation");
933
934 // Check if parent is loop-like. If it's not, do not add it to the worklist.
935 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
936 if (!parentLoop)
937 continue;
938
939 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
940 // No need to check for null as innerIterArg is tied to parentLoop.
941 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
942 Value loopBlockArgument =
943 parentLoop->getOperand(operand->getOperandNumber());
944 destWorklist.push_back(loopBlockArgument);
945 }
946 }
947
948 return false;
949}
950
951/// Find the first "extract" user of `producerOp` and tile it right before its
952/// use. The tiled op is fused under the `containingOp`.
953/// Return this fused op on success or nullptr if anything fails.
954/// If tiled op has uses that are dominated by `containingOp`, return
955/// a new `containingOp` with results of the fused op appended to
956/// results of the `containingOp` or nullptr if there are no dominated uses.
957static std::tuple<SmallVector<Operation *>, Operation *>
959 Operation *producerOp, Operation *containingOp) {
960 LDBG() << "Try to fuse a direct extract use";
961 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
962 if (!tileableProducer) {
963 diag.attachNote(producerOp->getLoc())
964 << "producer is not a TileableInterface: " << *producerOp;
965 return {};
966 }
967
968 // Search the producer slices accessed within the containing operation.
969 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
970 // evolve into an interface.
971 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
972 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
973 return sliceOp && containingOp->isProperAncestor(sliceOp);
974 });
975
976 // Find a fusion opportunity.
977 if (it == tileableProducer->getUsers().end()) {
978 diag.attachNote(tileableProducer->getLoc())
979 << "could not find fusion opportunity for: " << *tileableProducer;
980 return {};
981 }
982 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
983
984 // Try to fuse the producer in-place.
985 OpBuilder::InsertionGuard guard(rewriter);
986 rewriter.setInsertionPoint(sliceOpToTile);
987
988 // Clone the producer inside the consumer and try to update the producer init
989 // operands using the loop bbArgs if applicable. More precisely, if the bbArg
990 // of the container loop points to a value that it is used by the consumer op,
991 // then, instead of using such value on the consumer, use the value coming
992 // from the bbArg instead. This allows to reuse the output tensor (instead of
993 // creating a new one) of the container when both producer and container write
994 // to the same output.
995 if (LoopLikeOpInterface containerLoop =
996 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
997 Operation *clone = rewriter.clone(*producerOp);
998 rewriter.modifyOpInPlace(clone, [&]() {
999 // Iterate over the outputs of the producer and over the loop bbArgs and
1000 // check if any bbArg points to the same value as the producer output. In
1001 // such case, make the producer output point to the bbArg directly.
1002 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
1003 if (!dpsInterface)
1004 return;
1005
1006 for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1007 Value producerOperand =
1008 clone->getOperand(initOperandPtr.getOperandNumber());
1009 for (BlockArgument containerIterArg :
1010 containerLoop.getRegionIterArgs()) {
1011 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1012 Value consumerOperand =
1013 containerLoop->getOperand(bbArg->getOperandNumber());
1014 // The producer has the same init as the loop bbArg, use it.
1015 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
1016 initOperandPtr.set(containerIterArg);
1017 }
1018 }
1019 }
1020 });
1021
1022 tileableProducer = dyn_cast<TilingInterface>(clone);
1023 }
1024
1025 // Tile the producer.
1026 int64_t resultNumber =
1027 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1028 LDBG() << "resultNumber: " << resultNumber;
1029
1030 SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
1031 SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
1032
1033 FailureOr<TilingResult> tileAndFuseResult =
1034 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1035 sizes);
1036
1037 if (failed(tileAndFuseResult)) {
1038 diag.attachNote(tileableProducer->getLoc())
1039 << "failed to tile producer op: " << *tileableProducer;
1040 return {};
1041 }
1042
1043#ifndef NDEBUG
1044 for (auto *tiledOp : tileAndFuseResult->tiledOps) {
1045 LDBG() << "tiledProducer: " << *tiledOp;
1046 }
1047#endif
1048
1049 // Replace the extract op.
1050 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1051 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1052 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1053 if (failed(maybeRankReduced)) {
1054 diag.attachNote(producerOp->getLoc())
1055 << "shape types don't match (missing canonicalization?):\nTiledOp: "
1056 << tileAndFuseResult->tiledValues[0]
1057 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
1058 return {};
1059 }
1060 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1061
1062 // Add new outputs to containing op, if required
1063 Operation *newContainingOp = replaceForAllWithNewSignature(
1064 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
1065 resultNumber, offsets, sizes);
1066
1067 // Cleanup clone.
1068 if (isa<LoopLikeOpInterface>(containingOp))
1069 rewriter.eraseOp(tileableProducer);
1070
1071 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1072}
1073
1074/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
1075/// it is exactly the `containingOp`, otherwise bail.
1076/// Then, find the first "extract" user of the tied block argument and tile it
1077/// right before its "extract" use. The tiled op is fused under the
1078/// `containingOp`.
1079/// Return this fused op on success or nullptr if anything fails.
1082 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
1083 Operation *containingOp) {
1084 LDBG() << "Try to fuse an extract use through block argument";
1085
1086 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1087 if (!tileableProducer) {
1088 diag.attachNote(producerOp->getLoc())
1089 << "producer is not a TileableInterface: " << *producerOp;
1090 return {};
1091 }
1092
1093 // Search the first use by a "scf::ForallOp" user.
1094 scf::ForallOp forallOp;
1095 auto itProducerUses =
1096 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
1097 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1098 return forallOp;
1099 });
1100 // If it's not from the containing op, return.
1101 if (!forallOp || forallOp != containingOp) {
1102 diag.attachNote(tileableProducer->getLoc())
1103 << "could not find a use by the containing op: " << *tileableProducer;
1104 return {};
1105 }
1106
1107 // Search the producer slices accessed within the containing
1108 // operation.
1109 // TODO: Generalize to more extract/insert/parallel_insert triples.
1110 // Maybe evolve into an interface.
1111 OpOperand *pUse = &(*itProducerUses);
1112 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
1113
1114 // Search the producer slices accessed within the containing operation.
1115 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
1116 // evolve into an interface.
1117 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
1118 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1119 return sliceOp && containingOp->isProperAncestor(sliceOp);
1120 });
1121
1122 // Find a fusion opportunity.
1123 if (itBBArgUsers == bbArg.getUsers().end()) {
1124 diag.attachNote(containingOp->getLoc())
1125 << "could not find fusion opportunity for bbArg: " << bbArg;
1126 return {};
1127 }
1128 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1129
1130 // Try to fuse the producer in-place.
1131 OpBuilder::InsertionGuard guard(rewriter);
1132 rewriter.setInsertionPoint(sliceOpToTile);
1133
1134 // Replace the use in the tileableProducer before tiling: clone, replace and
1135 // then tile.
1136 int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
1137 LDBG() << "resultNumber: " << resultNumber;
1138
1139 // Gather destination tensors.
1140 SmallVector<Value> destinationTensors;
1142 rewriter, tileableProducer->getLoc(), tileableProducer,
1143 destinationTensors))) {
1144 diag.attachNote(tileableProducer->getLoc())
1145 << "failed to get destination tensors for: " << *tileableProducer;
1146 return {};
1147 }
1148
1149 IRMapping bvm;
1150 bvm.map(destinationTensors[resultNumber], bbArg);
1151 auto tileableProducerClone =
1152 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
1153 llvm::scope_exit scopeGuard(
1154 [&]() { rewriter.eraseOp(tileableProducerClone); });
1155
1156 // Tile the producer.
1157 FailureOr<TilingResult> tileAndFuseResult =
1158 tileableProducerClone.generateResultTileValue(
1159 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1160 sliceOpToTile.getMixedSizes());
1161 if (failed(tileAndFuseResult)) {
1162 diag.attachNote(tileableProducer->getLoc())
1163 << "failed to tile producer op: " << *tileableProducer;
1164 return {};
1165 }
1166
1167 // Replace the extract op.
1168 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1169 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1170 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1171 assert(succeeded(maybeRankReduced) && "unexpected shape");
1172 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1173
1174 // Replace the use in containingOp.
1175 rewriter.modifyOpInPlace(containingOp, [&]() {
1176 containingOp->setOperand(pUse->getOperandNumber(),
1177 destinationTensors.front());
1178 });
1179
1180 return tileAndFuseResult->tiledOps;
1181}
1182
1184 Operation *producerOp,
1185 Operation *containingOp) {
1186 LDBG() << "Try to fuse an use by cloning";
1187
1188 // Gather all uses inside the containing op.
1190 for (OpResult result : producerOp->getOpResults()) {
1191 for (OpOperand &use : result.getUses()) {
1192 if (containingOp->isProperAncestor(use.getOwner())) {
1193 uses.push_back(&use);
1194 continue;
1195 }
1196 // Cannot clone and fuse if the use is by the containing op itself: fail
1197 // immediately.
1198 if (containingOp == use.getOwner()) {
1199 diag.attachNote(producerOp->getLoc())
1200 << "producer op use by containing op cannot be fused by cloning";
1201 return nullptr;
1202 }
1203 }
1204 }
1205
1206 // Check for a non-empty list of fusion opportunities.
1207 if (uses.empty()) {
1208 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1209 return nullptr;
1210 }
1211
1212 // Clone and fuse inside the containing op.
1213 Operation *fusedOp = nullptr;
1214 OpOperand *use = uses.front();
1215 // Parallel insert slice is not a valid clone destination.
1216 // TODO: Generalize to other type of ops.
1217 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1218 "Parallel insert slice is not a valid clone destination");
1219 unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1220 LDBG() << "resultNumber: " << resultNumber;
1221
1222 OpBuilder::InsertionGuard guard(rewriter);
1223 rewriter.setInsertionPoint(use->getOwner());
1224 fusedOp = rewriter.clone(*producerOp);
1225 rewriter.modifyOpInPlace(
1226 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1227
1228 return fusedOp;
1229}
1230
1231bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1232 // Allow repeated handles since we are fusing everything anyway.
1233 return true;
1234}
1235
1237transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1240 SmallVector<Operation *> fusedOps;
1241 auto producerOps = state.getPayloadOps(getProducerOp());
1242 auto containingOps = state.getPayloadOps(getContainingOp());
1243 if (!llvm::hasSingleElement(containingOps)) {
1244 return emitDefiniteFailure()
1245 << "requires exactly one containing_op handle (got "
1246 << llvm::range_size(containingOps) << ")";
1247 }
1248 Operation *containingOp = *containingOps.begin();
1249
1250 // If nothing to fuse, propagate success.
1251 if (std::empty(producerOps)) {
1252 results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1253 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1255 }
1256
1257 // Helper function to find the next producer that should be fused. Take any
1258 // producer that has a use inside the containing op.
1259 SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1260 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1261 for (const auto &it : enumerate(remainingProducers)) {
1262 Operation *producerOp = it.value();
1263 // The containing op may be a user of producerOp: use isAncestor.
1264 int64_t numUsesInContainingOp =
1265 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1266 return containingOp->isAncestor(op);
1267 });
1268 // TODO: When resolving the TODO below (no duplicate ops), take an op
1269 // that has no use among the remaining producers. This is a topological
1270 // sorting.
1271 if (numUsesInContainingOp > 0) {
1272 if (numUsesInContainingOp == 1)
1273 remainingProducers.erase(remainingProducers.begin() + it.index());
1274 return producerOp;
1275 }
1276 }
1277 return failure();
1278 };
1279
1280 while (!remainingProducers.empty()) {
1281 auto nextProducer = getNextProducer();
1282 if (failed(nextProducer)) {
1283 auto diag = mlir::emitSilenceableFailure(getLoc())
1284 << "could not find next producer to fuse into container";
1285 diag.attachNote(containingOp->getLoc()) << "containing op";
1286 return diag;
1287 }
1288
1289 Operation *producerOp = *nextProducer;
1290
1291 // Default diagnostic, to be complemented with more failure information.
1293 diag << "could not fuse " << *producerOp << " into " << *containingOp;
1294
1295 // TODO: If there are multiple uses of the producer in the containing op,
1296 // we currently tile/clone the op multiple times (once per use). In some
1297 // cases, we can tile/clone once and reuse the value for each use.
1298 // Futhermore, producers should then be traversed according to a
1299 // topological sorting.
1300 auto [tiledOps, newContainingOp] =
1301 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1302 if (!tiledOps.empty()) {
1303 LDBG() << "\nFused a direct extract use\n" << *containingOp;
1304 fusedOps.append(tiledOps);
1305 if (newContainingOp) {
1306 // Update handles associated with the containing op so we don't need to
1307 // invalidate them. This is a hack to support better composability
1308 // between tiling and fusion while a proper mechanism is being
1309 // investigated.
1310 //
1311 // DO NOT replicate this elsewhere unless you understand what you are
1312 // doing.
1313 LogicalResult replacementStatus =
1314 rewriter.notifyPayloadOperationReplaced(containingOp,
1315 newContainingOp);
1316 (void)replacementStatus;
1317 assert(succeeded(replacementStatus) &&
1318 "unable to update transform state mapping");
1319 rewriter.eraseOp(containingOp);
1320 containingOp = newContainingOp;
1321 }
1322 continue;
1323 }
1324
1325 SmallVector<Operation *> tiledContainingOpOperand =
1327 rewriter, diag, producerOp, containingOp);
1328 if (!tiledContainingOpOperand.empty()) {
1329 LDBG() << "\nFused an extract use through block argument\n"
1330 << *containingOp;
1331 fusedOps.append(tiledContainingOpOperand);
1332 continue;
1333 }
1334
1335 Operation *cloned =
1336 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1337 if (cloned) {
1338 LDBG() << "\nFused an use by cloning\n" << *containingOp;
1339 fusedOps.push_back(cloned);
1340 continue;
1341 }
1343 }
1344
1345 results.set(cast<OpResult>(getFusedOp()), fusedOps);
1346 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1348}
1349
1350void transform::FuseIntoContainingOp::getEffects(
1352 consumesHandle(getProducerOpMutable(), effects);
1353 onlyReadsHandle(getContainingOpMutable(), effects);
1354 producesHandle(getOperation()->getOpResults(), effects);
1355 modifiesPayload(effects);
1356}
1357
1358//===----------------------------------------------------------------------===//
1359// GeneralizeOp
1360//===----------------------------------------------------------------------===//
1361
1363transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1364 LinalgOp target,
1367 // Exit early if no transformation is needed.
1368 if (isa<GenericOp>(target)) {
1369 results.push_back(target);
1371 }
1372 rewriter.setInsertionPoint(target);
1373 FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1374 if (succeeded(generic)) {
1375 results.push_back(generic->getOperation());
1377 }
1378 return emitDefaultSilenceableFailure(target);
1379}
1380
1381//===----------------------------------------------------------------------===//
1382// SpecializeOp
//===----------------------------------------------------------------------===//
1384
// NOTE(review): the doxygen extraction this chunk came from dropped the
// leading `DiagnosedSilenceableFailure` result-type line and the trailing
// `results`/`state` parameter lines of this signature — restore from upstream
// before building.
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
  // Exit early if the operation is not a generic.
  if (!isa<GenericOp>(target)) {
    results.push_back(target);
    // NOTE(review): a `return DiagnosedSilenceableFailure::success();` line
    // appears to have been dropped here during extraction.
  }
  rewriter.setInsertionPoint(target);
  // NOTE(review): the declaration of `opts` (the specialization options
  // struct) was dropped here — verify the exact type name against the linalg
  // specialization API before restoring. `emitCategoryOps` controls whether
  // category ops are emitted during specialization.
  opts.emitCategoryOps = getEmitCategory();
  FailureOr<LinalgOp> named =
      specializeGenericOp(rewriter, cast<GenericOp>(target), opts);
  if (succeeded(named)) {
    results.push_back(named->getOperation());
    // NOTE(review): early-success return also dropped here.
  }
  return emitDefaultSilenceableFailure(target);
}
1406
1407//===----------------------------------------------------------------------===//
1408// InterchangeOp
1409//===----------------------------------------------------------------------===//
1410
1412transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1413 GenericOp target,
1416 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1417 // Exit early if no transformation is needed.
1418 if (interchangeVector.empty()) {
1419 results.push_back(target);
1421 }
1422
1423 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1424 if (interchangeVector.size() != numLoops) {
1425 return emitSilenceableError()
1426 << getIteratorInterchangeAttrName() << " has length ("
1427 << interchangeVector.size()
1428 << ") different from the number of loops in the target operation ("
1429 << numLoops << ")";
1430 }
1431 FailureOr<GenericOp> res = interchangeGenericOp(
1432 rewriter, target, SmallVector<unsigned>(interchangeVector));
1433 if (failed(res))
1434 return emitDefiniteFailure() << "failed to apply";
1435 results.push_back(res->getOperation());
1437}
1438
1439LogicalResult transform::InterchangeOp::verify() {
1440 ArrayRef<int64_t> permutation = getIteratorInterchange();
1441 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1442 if (!std::is_permutation(sequence.begin(), sequence.end(),
1443 permutation.begin(), permutation.end())) {
1444 return emitOpError()
1445 << "expects iterator_interchange to be a permutation, found "
1446 << getIteratorInterchange();
1447 }
1448 return success();
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// LinalgCopyToMemrefOp
1453//===----------------------------------------------------------------------===//
1454
1455DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1456 transform::TransformRewriter &rewriter, Operation *targetOp,
1459
1460 // Check if the target can be converted.
1461 if (!isa<linalg::CopyOp>(targetOp)) {
1463 emitSilenceableError() << "only linalg.copy target ops are supported";
1464 diag.attachNote(targetOp->getLoc()) << "target op";
1465 return diag;
1466 }
1467
1468 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1469 if (!copyOp.hasPureBufferSemantics()) {
1471 emitSilenceableError()
1472 << "cannot transform a linalg.copy on tensors into a memref.copy";
1473 diag.attachNote(targetOp->getLoc()) << "target op";
1474 return diag;
1475 }
1476
1477 SmallVector<Value> inputs = copyOp.getInputs();
1478 SmallVector<Value> outputs = copyOp.getOutputs();
1479 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1480 assert(outputs.size() == 1 && "expected memref copy op with one output");
1481 Value input = inputs.front();
1482 Value output = outputs.front();
1483
1484 // linalg.copy supports different element types on source/dest whereas
1485 // memref.copy does not, so we must check that the source and dest types can
1486 // be handled by memref.copy and otherwise reject the transformation.
1487 if (!isa<ShapedType>(input.getType())) {
1489 emitSilenceableError()
1490 << "cannot transform a linalg.copy which input has no shape";
1491 diag.attachNote(targetOp->getLoc()) << "target op";
1492 return diag;
1493 }
1494
1495 // linalg.copy destination must be a shaped type.
1496 assert(isa<ShapedType>(output.getType()));
1497
1498 if (cast<ShapedType>(input.getType()).getElementType() !=
1499 cast<ShapedType>(output.getType()).getElementType()) {
1501 emitSilenceableError()
1502 << "cannot transform a linalg.copy with different source and "
1503 "destination element types ";
1504 diag.attachNote(targetOp->getLoc()) << "target op";
1505 return diag;
1506 }
1507
1508 // Target can be converted, do it.
1509 auto memrefCopyOp =
1510 rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1511
1512 results.push_back(memrefCopyOp);
1514}
1515
1516//===----------------------------------------------------------------------===//
1517// LowerPackOp
1518//===----------------------------------------------------------------------===//
1519
1520DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1521 transform::TransformRewriter &rewriter, linalg::PackOp target,
1522 transform::ApplyToEachResultList &transformResults,
1524 rewriter.setInsertionPoint(target);
1525 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1526 FailureOr<LowerPackResult> res =
1527 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1528 if (failed(res)) {
1529 return mlir::emitSilenceableFailure(target->getLoc())
1530 << "cannot lower to pad + expand + transpose";
1531 }
1532 transformResults.push_back(res->padOp);
1533 transformResults.push_back(res->expandShapeOp);
1534 transformResults.push_back(res->transposeOp);
1536}
1537
1538//===----------------------------------------------------------------------===//
1539// LowerUnPackOp
1540//===----------------------------------------------------------------------===//
1541
1542DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1543 transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1544 transform::ApplyToEachResultList &transformResults,
1546 rewriter.setInsertionPoint(target);
1547 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1548 FailureOr<LowerUnPackOpResult> res =
1549 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1550 if (failed(res)) {
1552 emitSilenceableError()
1553 << "cannot lower to transpose + collapse + extract";
1554 diag.attachNote(target->getLoc()) << "target payload op";
1555 return diag;
1556 }
1557 transformResults.push_back(res->emptyOp);
1558 transformResults.push_back(res->transposeOp);
1559 transformResults.push_back(res->collapseShapeOp);
1560 transformResults.push_back(res->extractSliceOp);
1561 transformResults.push_back(res->copyOp);
1563}
1564
1565//===---------------------------------------------------------------------===//
1566// MatchOp
1567//===---------------------------------------------------------------------===//
1568
1569void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1570 Value target, ArrayRef<StringRef> opNames) {
1571 result.addOperands(target);
1572 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1573 builder.getStrArrayAttr(opNames));
1574 result.addTypes(transform::AnyOpType::get(builder.getContext()));
1575}
1576
1577void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1578 TypeRange resultTypes, Value target,
1579 ArrayRef<StringRef> opNames) {
1580 result.addOperands(target);
1581 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1582 builder.getStrArrayAttr(opNames));
1583 result.addTypes(resultTypes);
1584}
1585
1587transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1590 llvm::StringSet<> strs;
1591 if (getOps().has_value())
1592 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1593
1594 auto payloadOps = state.getPayloadOps(getTarget());
1595 if (!llvm::hasSingleElement(payloadOps)) {
1596 return emitDefiniteFailure("requires exactly one target handle");
1597 }
1598
1600 bool incorrectNumOperandTypes = false;
1601 auto matchFun = [&](Operation *op) {
1602 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1603 return;
1604
1605 // Interfaces cannot be matched by name, just by ID.
1606 // So we specifically encode the interfaces we care about for this op.
1607 if (getInterface().has_value()) {
1608 auto iface = getInterface().value();
1609 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1610 !isa<LinalgOp>(op))
1611 return;
1612 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1613 !isa<TilingInterface>(op))
1614 return;
1615 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1616 !isa<LoopLikeOpInterface>(op))
1617 return;
1618 }
1619
1620 // Check if all specified attributes match.
1621 if (getOpAttrs().has_value()) {
1622 DictionaryAttr opAttrs = getOpAttrs().value();
1623 for (NamedAttribute attr : opAttrs) {
1624 if (attr.getName() == getInterfaceAttrName() ||
1625 attr.getName() == getOpsAttrName())
1626 continue;
1627 if (!op->hasAttr(attr.getName()))
1628 return;
1629 if (op->getAttr(attr.getName()) != attr.getValue())
1630 return;
1631 }
1632 }
1633
1634 if (getFilterResultType().has_value()) {
1635 Type t = getFilterResultType().value();
1636 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1637 return;
1638 }
1639
1640 if (getFilterOperandTypes().has_value()) {
1641 mlir::ArrayAttr types = getFilterOperandTypes().value();
1642 auto operandTypes = op->getOperandTypes();
1643
1644 if (types.size() == 1) {
1645 // All the operands must must be equal to the specified type
1646 auto typeattr =
1647 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1648 Type t = cast<::mlir::Type>(typeattr.getValue());
1649 if (!llvm::all_of(op->getOperandTypes(),
1650 [&](Type operandType) { return operandType == t; }))
1651 return;
1652 } else {
1653 // The operand types must match all the types in the list (in the same
1654 // order in with they are specified)
1655 if (types.size() != operandTypes.size()) {
1656 incorrectNumOperandTypes = true;
1657 return;
1658 }
1659
1660 for (auto [attr, operandType] :
1661 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1662 auto typeattr = cast<mlir::TypeAttr>(attr);
1663 Type type = cast<::mlir::Type>(typeattr.getValue());
1664
1665 if (type != operandType)
1666 return;
1667 }
1668 }
1669 }
1670
1671 // All constraints are satisfied.
1672 res.push_back(op);
1673 return;
1674 };
1675
1676 (*payloadOps.begin())->walk(matchFun);
1677 if (incorrectNumOperandTypes)
1678 return emitDefiniteFailure("If filter_operand_types contains more than a "
1679 "type, then it must contain as much types as "
1680 "the number of operands in the target ops");
1681 results.set(cast<OpResult>(getResult()), res);
1683}
1684
1685//===---------------------------------------------------------------------===//
1686// MultiTileSizesOp
1687//===---------------------------------------------------------------------===//
1688
1690 Type targetType, Type lowSizeType, Type,
1691 Type) {
1692 printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1693}
1694
1695static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1696 Type &targetType, Type &lowSizeType,
1697 Type &highSizeType,
1698 Type &splitPointType) {
1699 FunctionType funcType;
1700 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1701 if (failed(parser.parseType<FunctionType>(funcType)))
1702 return failure();
1703
1704 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1705 parser.emitError(typeLoc) << "expects a trailing functional type with one "
1706 "argument and one result";
1707 }
1708 targetType = funcType.getInput(0);
1709 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1710
1711 return success();
1712}
1713
1714DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1715 transform::TransformRewriter &rewriter, LinalgOp target,
1717 if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1718 if (target.hasDynamicShape()) {
1719 auto diag = emitSilenceableError()
1720 << "cannot compute parametric tile sizes for dynamically "
1721 "shaped payload op";
1722 diag.attachNote(target->getLoc()) << "payload op";
1723 return diag;
1724 }
1725
1726 FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1727 target, getDimension(), getTargetSize(), getDivisor());
1728 if (failed(spec)) {
1729 return emitSilenceableError()
1730 << "failed to compute multi-size tiling sizes";
1731 }
1732
1733 Builder builder(target.getContext());
1734 results.assign(llvm::map_range(
1735 ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1736 spec->lowTileSize * spec->lowTripCount}),
1737 [&builder, this](int64_t value) {
1738 return builder.getIntegerAttr(
1739 cast<ParamType>(getLowSize().getType()).getType(), value);
1740 }));
1742 }
1743
1744 OpBuilder builder(target.getContext());
1745 builder.setInsertionPoint(target);
1746 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1747 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1748 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1749 builder, target, getDimension(), targetSize, divisor);
1750 if (failed(spec)) {
1751 return emitSilenceableError() << "could not generate tile size computation";
1752 }
1753
1754 AffineExpr s0 = builder.getAffineSymbolExpr(0);
1755 AffineExpr s1 = builder.getAffineSymbolExpr(1);
1756 Operation *splitPoint =
1757 affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1758 {spec->lowTileSize, spec->lowTripCount});
1759 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1760 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1761 assert(lowTileSize && highTileSize && splitPoint &&
1762 "tile sizes are not produced by operations");
1763 results.reserve(results.size() + 3);
1764 results.push_back(lowTileSize);
1765 results.push_back(highTileSize);
1766 results.push_back(splitPoint);
1768}
1769
// Declares the transform-dialect side effects of MultiTileSizesOp: the target
// handle is only read and all op results are produced. The payload is only
// read when the results are transform params (sizes computed statically);
// otherwise size-computation ops are materialized into the payload IR.
void transform::MultiTileSizesOp::getEffects(
  onlyReadsHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  // Param-typed results are computed without touching the payload; SSA-valued
  // results require inserting arith/affine ops next to the target.
  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
    onlyReadsPayload(effects);
  else
    modifiesPayload(effects);
}
1779
1780LogicalResult transform::MultiTileSizesOp::verify() {
1781 if (getLowSize().getType() != getHighSize().getType() ||
1782 getLowSize().getType() != getSplitPoint().getType()) {
1783 return emitOpError() << "expects all results type to be the same";
1784 }
1785 return success();
1786}
1787
1788//===---------------------------------------------------------------------===//
1789// PackOp
1790//===---------------------------------------------------------------------===//
1791
1792void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1793 Value target,
1794 ArrayRef<OpFoldResult> mixedPackedSizes) {
1795 SmallVector<int64_t> staticPackedSizes;
1796 SmallVector<Value> dynamicPackedSizes;
1797 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1798 staticPackedSizes);
1799 // Call the default builder which sets up the proper operands segment sizes
1800 // attributes for multiple variadic operands. In the absence of this, horrible
1801 // bugs ensue.
1802 Type linalgOpHType = transform::OperationType::get(
1803 builder.getContext(), GenericOp::getOperationName());
1804 build(builder, result,
1805 /*resultType=*/linalgOpHType,
1806 /*target=*/target,
1807 /*dynamic_sizes=*/dynamicPackedSizes,
1808 /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1809}
1810
1811SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1812 Builder b(getContext());
1813 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1814}
1815
// Applies data-tiling packing to the single LinalgOp payload op mapped by the
// target handle, using the (mixed static/dynamic) packed sizes, and maps the
// packed op to the result handle.
transform::PackOp::apply(transform::TransformRewriter &rewriter,
                         transform::TransformResults &transformResults,
  auto targetOps = state.getPayloadOps(getTarget());
  // If nothing to pack, propagate success.
  if (std::empty(targetOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()),
  }
  // Fail on multi-op handles.
  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 LinalgOp (got "
           << llvm::range_size(targetOps) << ")";
  }
  // Fail on mismatched number of pack sizes.
  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
    return emitSilenceableError()
           << "requires number of packed sizes match the number of loops ("
           << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
           << ")";
  }

  // Unpack handles to constants or actual SSA index values.
  SmallVector<OpFoldResult> packedSizes;
      state, *this, packedSizes, getMixedPackedSizes());

  rewriter.setInsertionPoint(linalgOp);
  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
  // pack() failing after the preconditions above is treated as unrecoverable.
  if (failed(maybeResult))
    return emitDefiniteFailure("data tiling failed");

  transformResults.set(cast<OpResult>(getPackedOp()),
                       {maybeResult->packedLinalgOp.getOperation()});
}
1856
// Transform-dialect effects of PackOp: the target handle is consumed (the
// payload op is rewritten), the packed-sizes handle is only read, and all op
// results are produced.
void transform::PackOp::getEffects(
  transform::consumesHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
}
1864
1865//===---------------------------------------------------------------------===//
1866// PackGreedilyOp.
1867//===---------------------------------------------------------------------===//
1868
1869LogicalResult transform::PackGreedilyOp::verify() {
1870 if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1871 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1872 << " is not a valid permutation";
1873 }
1874 // TODO: relax to allow empty once we have another strategy than just matmul.
1875 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1876 for (auto [s, nmo] :
1877 llvm::zip_equal(getMixedMatmulPackedSizes(),
1878 getMatmulPaddedSizesNextMultipleOf())) {
1879 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1880 if (nmo != 0 &&
1881 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1882 return emitOpError() << "at most one of the packed_size and the "
1883 "padded_sizes_next_multiple_of can be nonzero "
1884 "for the matmul strategy";
1885 }
1886 }
1887 }
1888 return success();
1889}
1890
// Greedily packs every LinalgOp payload op as a matmul; ops that cannot be
// packed (non-Linalg or failed packing) are forwarded unchanged so the result
// handle stays in sync with the target handle.
PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
                      transform::TransformResults &transformResults,
  for (Operation *op : state.getPayloadOps(getTarget())) {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    // Non-Linalg payload ops are silently skipped.
    if (!linalgOp)
      continue;
    // linalgOp will be replaced and the insertion point may be invalidated if
    // we set it before -> set it after.
    rewriter.setInsertionPointAfter(linalgOp);
    // Failing to pack greedily is perfectly fine.
    // In the future we will want to order packings according to some metric.
    FailureOr<PackResult> packResult = packMatmulGreedily(
        /*rewriter=*/rewriter,
        /*linalgOp=*/linalgOp,
        /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
        /*mnkPaddedSizesNextMultipleOf=*/
        getMatmulPaddedSizesNextMultipleOf(),
        /*mnkOrder=*/getMatmulInnerDimsOrder());
    if (succeeded(packResult)) {
      results.push_back(packResult->packedLinalgOp);
      continue;
    }
    // On failure, keep the original op in the results.
    results.push_back(linalgOp);
  }
  transformResults.set(cast<OpResult>(getPackedOp()), results);
}
1921
1922SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1923 Builder b(getContext());
1924 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1925 b);
1926}
1927
// Transform-dialect effects of PackGreedilyOp: the target handle is consumed
// (payload ops may be rewritten), the packed-sizes handle is only read, and
// all op results are produced.
void transform::PackGreedilyOp::getEffects(
  transform::consumesHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
}
1935
1936//===---------------------------------------------------------------------===//
1937// PackTransposeOp
1938//===---------------------------------------------------------------------===//
1939
1940LogicalResult transform::PackTransposeOp::verify() {
1941 if (!isPermutationVector(getInnerPerm())) {
1942 return emitOpError() << getInnerPermAttrName()
1943 << " is not a valid permutation";
1944 }
1945 if (!isPermutationVector(getOuterPerm())) {
1946 return emitOpError() << getOuterPermAttrName()
1947 << " is not a valid permutation";
1948 }
1949 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1950 return emitOpError() << " at least one of " << getInnerPermAttrName()
1951 << " or " << getOuterPermAttrName()
1952 << " must be specified";
1953 }
1954 return success();
1955}
1956
namespace {
/// Selects whether a permutation applies to the outer dims or the inner dims
/// of a pack/unpack op; used by isValidPackingPermutation below.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
1960
/// Return true if `permutation` is a valid permutation of the
/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
/// (case OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack`
/// `op`.
/// This is the case when the `permutation` rank matches the rank expected by
/// `op` and `permutation` is itself a permutation vector.
/// Return true if either `op` or `permutation` are empty to allow a simpler
/// polymorphic implementation.
1968template <typename RelayoutOpTy>
1969static bool isValidPackingPermutation(
1970 RelayoutOpTy op, ArrayRef<int64_t> permutation,
1971 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1972 static_assert(
1973 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1974 "applies to only pack or unpack operations");
1975 if (!op || permutation.empty())
1976 return true;
1977 size_t innerRank = op.getInnerDimsPos().size();
1978 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1979 return permutation.size() == innerRank && isPermutationVector(permutation);
1980 // op.getOuterDimsPerm() may be empty, in which case it is identity.
1981 // Don't rely on it.
1982 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1983 return permutation.size() == op.getSourceRank() &&
1984 isPermutationVector(permutation);
1985 }
1986 return permutation.size() == op.getDestRank() &&
1987 isPermutationVector(permutation);
1988}
1989
// Transposes a matched linalg.pack (and optionally its symmetric
// linalg.unpack) together with the packed LinalgOp they surround, using the
// outer_perm / inner_perm attributes, then maps the transposed ops to the
// result handles.
transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &transformResults,
  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
  // Step 1. If nothing to pack, propagate success.
  if (std::empty(packOrUnpackOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()), {});
    transformResults.set(cast<OpResult>(getPackOp()), {});
    transformResults.set(cast<OpResult>(getUnPackOp()), {});
  }

  // Step 2. Bunch of runtime sanity check and error messages.
  // Step 2.1. Fail on multi-op handles.
  if (!llvm::hasSingleElement(packOrUnpackOps) ||
      !llvm::hasSingleElement(linalgOps)) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 "
              "packing op and 1 packed op ("
           << "got " << llvm::range_size(packOrUnpackOps) << " and "
           << llvm::range_size(linalgOps) << ")";
  }

  // Step 2.2. Fail on wrong type.
  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
  if ((!packOp && !unPackOp)) {
    return emitSilenceableError() << "requires target to map to a "
                                     "linalg.pack or linalg.unpack";
  }
  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
  if (!linalgOpTarget)
    return emitSilenceableError() << "requires a LinalgOp target";

  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
  LinalgOp linalgOp;
  if (packOp && packOp.getResult().hasOneUse())
    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
  else if (unPackOp)
    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
  if (linalgOp != linalgOpTarget) {
    auto errorMsg =
        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
               : StringLiteral{"not produced by the LinalgOp target"};
    return emitSilenceableError() << errorMsg;
  }

  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
  // PackOp.
  if (unPackOp) {
    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
    // The pack op feeding the init operand that the unpack result corresponds
    // to is the symmetric pack op.
    OpOperand *packUse = linalgOp.getDpsInitOperand(
        cast<OpResult>(unPackOp.getSource()).getResultNumber());
    packOp = packUse->get().getDefiningOp<linalg::PackOp>();
    if (!packOp || !packOp.getResult().hasOneUse())
      return emitSilenceableError() << "could not find matching pack op";
  }

  // Step 2.5. Fail if any permutation does not validate.
  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
    ArrayRef<int64_t> perm =
        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
                        ? StringLiteral{"invalid outer_perm"}
                        : StringLiteral{"invalid inner_perm"};
    if (!isValidPackingPermutation(packOp, perm, permType) ||
        !isValidPackingPermutation(unPackOp, perm, permType)) {
      Operation *packOrUnpackOp =
          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
    }
  }

  // From here on, packOp and linalgOp are always present, unPackOp may or may
  // not be present.
  assert(packOp && linalgOp && "unexpected null op");

  // Step 3. Actually transpose the ops.
  FailureOr<PackTransposeResult> res = packTranspose(
      rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
  // Preconditions have been checked, it is an error to fail here.
  assert(succeeded(res) && "unexpected packTranspose failure");

  // Step 4. Return results.
  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
  transformResults.set(cast<OpResult>(getPackedOp()),
                       {res->transposedLinalgOp});
  if (unPackOp) {
    transformResults.set(cast<OpResult>(getUnPackOp()),
                         {res->transposedUnPackOp});
  } else {
    // No unpack was matched: the unpack result handle maps to nothing.
    transformResults.set(cast<OpResult>(getUnPackOp()), {});
  }

}
2088
2089//===---------------------------------------------------------------------===//
2090// PadOp
2091//===---------------------------------------------------------------------===//
2092
/// Builder for the all-static form of PadOp: pad-to-multiple-of values are
/// encoded as a DenseI64ArrayAttr (null when empty) and no dynamic operands
/// are created; padding values are left for inference.
void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
                             ArrayRef<int64_t> paddingDimensions,
                             ArrayRef<int64_t> padToMultipleOf,
                             ArrayRef<int64_t> nofoldFlags,
                             ArrayRef<Attribute> transposePaddings,
                             StringRef copyBackOp,
                             bool usePrescribedTensorShapes) {
  auto resultType = transform::AnyOpType::get(b.getContext());
  return build(/*odsBuilder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType},
               /*target=*/target,
               /*padding_values=*/ArrayAttr(), // let inference handle this
               /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
               /*pad_to_multiple_of=*/ValueRange{},
               /*padToMultipleOf=*/
               (padToMultipleOf.empty()
                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
               /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
               /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
               /*copy_back_op=*/b.getStringAttr(copyBackOp),
               /*use_prescribed_tensor_shapes=*/
               usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
}
2118
2119void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2120 ArrayRef<int64_t> paddingDimensions,
2121 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
2122 ArrayRef<int64_t> nofoldFlags,
2123 ArrayRef<Attribute> transposePaddings,
2124 StringRef copyBackOp,
2125 bool usePrescribedTensorShapes) {
2126 auto resultType = transform::AnyOpType::get(b.getContext());
2127 SmallVector<int64_t> staticPadToMultipleOf;
2128 SmallVector<Value> dynamicPadToMultipleOf;
2129 dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
2130 staticPadToMultipleOf);
2131 return build(/*odsBuilder=*/b,
2132 /*result=*/result,
2133 /*types=*/TypeRange{resultType, resultType},
2134 /*target=*/target,
2135 /*padding_values=*/ArrayAttr(), // let inference handle this
2136 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2137 /*pad_to_multiple_of=*/dynamicPadToMultipleOf,
2138 /*padToMultipleOf=*/staticPadToMultipleOf,
2139 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2140 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2141 /*copy_back_op=*/copyBackOp,
2142 /*use_prescribed_tensor_shapes=*/usePrescribedTensorShapes);
2143}
2144
// Transform-dialect effects of PadOp: the target handle is consumed (padded
// ops replace the originals), the pad-to-multiple-of handle is only read, all
// op results are produced, and the payload IR is modified.
void PadOp::getEffects(
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
2152
2153SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2154 Builder b(getContext());
2155 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
2156}
2157
// Pads each LinalgOp payload op according to the op's attributes (padding
// values/dimensions, multiples, nofold flags, transposes), replaces the
// original op, and maps the padded ops, the created tensor.pad ops, and the
// copy-back ops to the three result handles.
DiagnosedSilenceableFailure
transform::PadOp::apply(transform::TransformRewriter &rewriter,
                        transform::TransformResults &results,
                        transform::TransformState &state) {
  auto transformOp = cast<TransformOpInterface>(getOperation());
  SmallVector<Operation *> paddedOps, padOps, copyBackOps;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto linalgTarget = dyn_cast<LinalgOp>(target);
    if (!linalgTarget) {
      auto diag = emitSilenceableError() << "expected LinalgOp target";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Convert the integer packing flags to booleans.
    SmallVector<bool> nofoldFlags;
    for (int64_t packPadding :
         extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
      nofoldFlags.push_back(static_cast<bool>(packPadding));

    // Convert the padding values to attributes.
    SmallVector<Attribute> paddingValues;
    for (auto const &[untypedAttr, elementOrTensorType] :
         llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {

      // Poison padding values are passed through as-is.
      if (matchPattern(untypedAttr, ub::m_Poison())) {
        paddingValues.push_back(untypedAttr);
        continue;
      }
      auto attr = dyn_cast<TypedAttr>(untypedAttr);
      if (!attr) {
        emitOpError("expects padding values to be typed attributes or poison");
      }
      Type elementType = getElementTypeOrSelf(elementOrTensorType);
      // Try to parse string attributes to obtain an attribute of element type.
      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
            stringAttr, getContext(), elementType,
            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError("expects a padding that parses to ")
                      << elementType << ", got " << untypedAttr;
          diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
        }
        paddingValues.push_back(parsedAttr);
        continue;
      }
      // Otherwise, add the attribute directly.
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError("expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
      }
      paddingValues.push_back(attr);
    }

    // Extract the transpose vectors.
    SmallVector<SmallVector<int64_t>> transposePaddings;
    for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
      transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
          cast<ArrayAttr>(transposeVector)));

    LinalgOp paddedOp;
    LinalgPaddingOptions options;
    options.paddingDimensions =
        extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());

    // Resolve mixed param/handle pad-to-multiple-of values into constants;
    // default to 1 (no rounding) for every padded dimension when unspecified.
    SmallVector<int64_t> padToMultipleOf;
    DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
        state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
    if (!status.succeeded())
      return status;
    if (padToMultipleOf.empty())
      padToMultipleOf =
          SmallVector<int64_t>(options.paddingDimensions.size(), 1);

    options.padToMultipleOf = padToMultipleOf;
    options.paddingValues = paddingValues;
    options.nofoldFlags = nofoldFlags;
    // Map the copy_back_op string attribute to the corresponding option;
    // verify() has already rejected anything else.
    if (getCopyBackOp() ==
        bufferization::MaterializeInDestinationOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
          BufferizationMaterializeInDestination;
    } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
    } else if (getCopyBackOp() == kCopyOpNone) {
      options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
    } else {
      llvm_unreachable("unsupported copy_back op");
    }
    // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
    bool irChanged = false;
    if (getUsePrescribedTensorShapes() &&
        linalgTarget.hasPureTensorSemantics()) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(linalgTarget);
      for (OpOperand &operand : linalgTarget->getOpOperands()) {
        for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
          if (ShapedType::isStatic(dim))
            continue;
          options.setSizeToPadTo(operand.getOperandNumber(), i,
                                 tensor::getMixedSize(rewriter,
                                                      operand.get().getLoc(),
                                                      operand.get(), i));
          irChanged = true;
        }
      }
    }

    SmallVector<Value> replacements;
    SmallVector<tensor::PadOp> newPadOps;
    if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
                                 replacements, newPadOps))) {
      // If size-reification ops were already created above, the IR has been
      // mutated and the failure cannot be silenced.
      if (irChanged) {
        auto diag = emitDefiniteFailure() << "failed to pad op";
        diag.attachNote(target->getLoc()) << "target op";
        return diag;
      }
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(linalgTarget, replacements);
    paddedOps.push_back(paddedOp);
    padOps.append(newPadOps.begin(), newPadOps.end());
    // Collect the (deduplicated) copy-back ops defining the replacements.
    if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
      for (Value v : replacements) {
        Operation *copyBackOp = v.getDefiningOp();
        if (!llvm::is_contained(copyBackOps, copyBackOp))
          copyBackOps.push_back(copyBackOp);
      }
    }
  }

  results.set(cast<OpResult>(getPadded()), paddedOps);
  results.set(cast<OpResult>(getPad()), padOps);
  results.set(cast<OpResult>(getCopy()), copyBackOps);
}
2307
2308LogicalResult transform::PadOp::verify() {
2309 SmallVector<int64_t> nofoldFlags =
2310 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2311 if (any_of(nofoldFlags, [](int64_t packPadding) {
2312 return packPadding != 0 && packPadding != 1;
2313 })) {
2314 return emitOpError()
2315 << "expects nofold_flags to contain booleans (0/1), found "
2316 << getNofoldFlags();
2317 }
2318
2319 SmallVector<int64_t> paddingDimensions =
2320 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2321 if (any_of(paddingDimensions,
2322 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2323 return emitOpError() << "expects padding_dimensions to contain positive "
2324 "integers, found "
2325 << getPaddingDimensions();
2326 }
2327 if (!getMixedPadToMultipleOf().empty()) {
2328 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2329 return emitOpError() << "expects as many multiples as padding_dimensions";
2330 }
2331 }
2332 ArrayAttr transposes = getTransposePaddings();
2333 for (Attribute attr : transposes) {
2334 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2335 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2336 if (!std::is_permutation(sequence.begin(), sequence.end(),
2337 transpose.begin(), transpose.end())) {
2338 return emitOpError()
2339 << "expects transpose_paddings to be a permutation, found "
2340 << attr;
2341 }
2342 }
2343 if (getCopyBackOp() !=
2344 bufferization::MaterializeInDestinationOp::getOperationName() &&
2345 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2346 getCopyBackOp() != kCopyOpNone)
2347 return emitOpError() << "invalid copy_back_op";
2348 return success();
2349}
2350
2351//===---------------------------------------------------------------------===//
2352// PadTilingInterfaceOp
2353//===---------------------------------------------------------------------===//
2354
2355void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2356 OperationState &result,
2357 Value target,
2358 ArrayRef<int64_t> paddingSizes,
2359 bool padToMultipleOf) {
2360 auto resultType = transform::AnyOpType::get(b.getContext());
2361 return build(/*odsBuilder=*/b,
2362 /*result=*/result,
2363 /*types=*/TypeRange{resultType, resultType},
2364 /*target=*/target,
2365 /*padding_values=*/ArrayAttr(), // let inference handle this
2366 /*padding_sizes=*/ValueRange{},
2367 /*paddingSizes=*/
2368 (paddingSizes.empty() ? DenseI64ArrayAttr()
2369 : b.getDenseI64ArrayAttr(paddingSizes)),
2370 /*pad_to_multiple_of=*/
2371 padToMultipleOf ? b.getUnitAttr() : nullptr);
2372}
2373
/// Builder taking mixed (static/dynamic) padding sizes; splits them into the
/// static attribute and dynamic operand lists expected by the generated
/// builder. Padding values are left for inference.
void transform::PadTilingInterfaceOp::build(
    OpBuilder &b, OperationState &result, Value target,
    ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
  auto resultType = transform::AnyOpType::get(b.getContext());
  SmallVector<int64_t> staticPaddingSizes;
  SmallVector<Value> dynamicPaddingSizes;
  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
                             staticPaddingSizes);
  return build(/*odsBuilder=*/b,
               /*result=*/result,
               /*types=*/TypeRange{resultType, resultType},
               /*target=*/target,
               /*padding_values=*/ArrayAttr(), // let inference handle this
               /*padding_sizes=*/dynamicPaddingSizes,
               /*paddingSizes=*/staticPaddingSizes,
               /*pad_to_multiple_of=*/padToMultipleOf);
}
2391
2392void transform::PadTilingInterfaceOp::getEffects(
2393 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2394 consumesHandle(getTargetMutable(), effects);
2395 onlyReadsHandle(getPaddingSizesMutable(), effects);
2396 producesHandle(getOperation()->getOpResults(), effects);
2397 modifiesPayload(effects);
2398}
2399
2400SmallVector<OpFoldResult>
2401transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2402 Builder b(getContext());
2403 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2404}
2405
// Pads each TilingInterface payload op to the requested padding sizes (or
// multiples), replaces the original op with the sliced results of the padded
// op, and maps the padded ops and created pad ops to the result handles.
DiagnosedSilenceableFailure
transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> paddedOps, padOps;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto targetOp = dyn_cast<TilingInterface>(target);
    if (!targetOp) {
      auto diag = emitSilenceableError() << "expected TilingInterface target";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
    // loopsToOperand map / C++ APIs to compute the effect of padding on
    // operands.
    if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
      auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
                                            "supported atm";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }

    // Convert the padding values to attributes.
    SmallVector<Attribute> paddingValues;
    for (auto const &[untypedAttr, elementOrTensorType] :
         llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
      auto attr = dyn_cast<TypedAttr>(untypedAttr);
      Type elementType = getElementTypeOrSelf(elementOrTensorType);

      // Poison padding values are passed through as-is.
      if (matchPattern(untypedAttr, ub::m_Poison())) {
        paddingValues.push_back(untypedAttr);
        continue;
      }
      if (!attr) {
        emitOpError("expects padding values to be typed attributes or poison");
      }
      // Try to parse string attributes to obtain an attribute of element type.
      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
        auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
            stringAttr, getContext(), elementType,
            /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError("expects a padding that parses to ")
                      << elementType << ", got " << attr;
          diag.attachNote(targetOp.getLoc()) << "when applied to this op";
        }
        paddingValues.push_back(parsedAttr);
        continue;
      }
      // Otherwise, add the attribute directly.
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError("expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(targetOp.getLoc()) << "when applied to this op";
      }
      paddingValues.push_back(attr);
    }

    // Set options.
    PadTilingInterfaceOptions options;
    options.setPaddingValues(paddingValues)
        .setPaddingSizes(getMixedPaddingSizes())
        .setPadToMultipleOf(getPadToMultipleOf());

    // Insert the padded op after the original so the replacement below sees a
    // valid insertion point.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(targetOp);
    auto maybePadOps = rewriteAsPaddedOp(
        rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
    if (failed(maybePadOps)) {
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(target->getLoc()) << "target op";
      return diag;
    }
    const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();

    // Set transform results.
    paddedOps.push_back(paddedOp);
    padOps.append(paddedOperands.begin(), paddedOperands.end());
    rewriter.replaceOp(targetOp.getOperation(), slicedResults);
  }

  results.set(cast<OpResult>(getPadded()), paddedOps);
  results.set(cast<OpResult>(getPad()), padOps);
}
2496
2497LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2498
2499//===---------------------------------------------------------------------===//
2500// HoistPadOp
2501//===---------------------------------------------------------------------===//
2502
// Builds the packing loop nest that hoists the single tensor.pad payload op
// out of the single scf.for payload loop, and maps the resulting outermost
// packing loop (or the hoisted pad when no loops were cloned) to the result
// handle.
DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &transformResults,
    transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  auto loopOps = state.getPayloadOps(getLoop());
  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
    return emitDefiniteFailure()
           << "requires exactly one target and one loop handle (got "
           << llvm::range_size(targetOps) << " and "
           << llvm::range_size(loopOps) << ")";
  }

  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
  if (!padOp || !loopOp)
    return emitDefiniteFailure() << "requires exactly 2 non-null handles";

  FailureOr<linalg::detail::PackingResult> result =
      linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
                                           getTranspose());
  if (failed(result))
    return emitDefiniteFailure() << "could not build packing loop nest";

  // With no cloned loop IVs there is no loop to return; map the hoisted pad
  // op itself to the result handle.
  if (result->clonedLoopIvs.empty()) {
    transformResults.set(cast<OpResult>(getPackingLoop()),
                         {result->hoistedPadOp.getOperation()});
  }
  auto outerPackedLoop =
      scf::getForInductionVarOwner(result->clonedLoopIvs.front());
  transformResults.set(cast<OpResult>(getPackingLoop()),
                       {outerPackedLoop.getOperation()});
}
2538
2539LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2540 ArrayRef<int64_t> transpose = getTranspose();
2541 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2542 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2543 transpose.end())) {
2544 return emitOpError() << "expects transpose to be a permutation, found "
2545 << getTranspose();
2546 }
2547 return success();
2548}
2549
// Transform-dialect effects: both the pad-op and loop handles are only
// inspected, and all op results are produced.
void transform::HoistPadBuildPackingLoopNestOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::onlyReadsHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getLoopMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
}
2557
// Hoists the target tensor.pad op out of `getNumLoops()` enclosing loops
// (optionally transposing via `getTranspose()`), replaces the original pad
// with the hoisted computation, and reports the hoisted pad op.
DiagnosedSilenceableFailure
transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
                                  tensor::PadOp target,
                                  transform::ApplyToEachResultList &results,
                                  transform::TransformState &state) {
  tensor::PadOp hoistedPadOp;
  SmallVector<TransposeOp> transposeOps;
  FailureOr<Value> result =
      hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
                            hoistedPadOp, transposeOps);
  if (succeeded(result)) {
    // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
    // need to be different.
    // TODO: clean this up and stop "pad and hoist" behavior more globally now
    // that we have more composable abstractions.
    rewriter.replaceOp(target, *result);
    results.push_back(hoistedPadOp);
  }
  // Hoisting failed: emit the default silenceable failure for this target.
  return emitDefaultSilenceableFailure(target);
}
2580
2581LogicalResult transform::HoistPadOp::verify() {
2582 ArrayRef<int64_t> transpose = getTranspose();
2583 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2584 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2585 transpose.end())) {
2586 return emitOpError() << "expects transpose to be a permutation, found "
2587 << getTranspose();
2588 }
2589 return success();
2590}
2591
2592//===----------------------------------------------------------------------===//
2593// PromoteOp
2594//===----------------------------------------------------------------------===//
2595
/// Promotes the (sub)views of `target` to faster memory according to the op's
/// attributes. Each attribute is translated into the corresponding
/// LinalgPromotionOptions setting, only when present / non-default.
DiagnosedSilenceableFailure
transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
                                 LinalgOp target,
                                 transform::ApplyToEachResultList &results,
                                 transform::TransformState &state) {
  LinalgPromotionOptions promotionOptions;
  if (!getOperandsToPromote().empty())
    promotionOptions = promotionOptions.setOperandsToPromote(
        extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
  if (getUseFullTilesByDefault())
    promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
        getUseFullTilesByDefault());
  if (getUseOriginalSubviewSize())
    promotionOptions =
        promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
  if (getUseAlloca())
    promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
  if (!getUseFullTileBuffers().empty())
    promotionOptions = promotionOptions.setUseFullTileBuffers(
        llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
  if (getAlignment().has_value())
    promotionOptions = promotionOptions.setAlignment(*getAlignment());
  if (getMemorySpace().has_value())
    promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());

  if (getMapping().has_value()) {
    // The mapping should only contain an element
    auto mapping = *getMapping();
    if (mapping.size() > 1)
      return emitDefaultDefiniteFailure(target);

    auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);

    // Only the GPU workgroup and private address spaces are supported as
    // promotion targets; any other mapping is a definite failure.
    if (addressSpace.getAddressSpace() ==
        mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setUseFullTileBuffers({false, false});
    } else if (addressSpace.getAddressSpace() ==
               mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
      promotionOptions =
          promotionOptions
              .setUseFullTileBuffers({false, false});
    } else {
      return emitDefaultDefiniteFailure(target);
    }
  }

  // Bail out when promotion is known not to apply to this op/options combo.
  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
    return emitDefaultDefiniteFailure(target);

  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
  if (failed(res))
    return emitDefaultDefiniteFailure(target);
  results.push_back(target);
}
2660
2661//===----------------------------------------------------------------------===//
2662// ReplaceOp
2663//===----------------------------------------------------------------------===//
2664
/// Replaces every payload op associated with the target handle by a clone of
/// the single op held in this op's body region. Targets must have no operands
/// and no (non-isolated) regions so the replacement is always valid.
DiagnosedSilenceableFailure
transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
                            TransformResults &transformResults,
                            TransformState &state) {
  auto payload = state.getPayloadOps(getTarget());

  // Check for invalid targets.
  for (Operation *target : payload) {
    if (target->getNumOperands() > 0)
      return emitDefiniteFailure() << "expected target without operands";
    if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
        target->getNumRegions() > 0)
      return emitDefiniteFailure()
             << "expected target that is isolated from above";
  }

  // Clone and replace.
  Operation *pattern = &getBodyRegion().front().front();
  SmallVector<Operation *> replacements;
  for (Operation *target : payload) {
    // Never replace an op nested inside this transform op itself.
    if (getOperation()->isAncestor(target))
      continue;
    rewriter.setInsertionPoint(target);
    Operation *replacement = rewriter.clone(*pattern);
    rewriter.replaceOp(target, replacement->getResults());
    replacements.push_back(replacement);
  }
  // Associate the result handle with all cloned replacements.
  transformResults.set(cast<OpResult>(getReplacement()), replacements);
}
2695
2696void transform::ReplaceOp::getEffects(
2697 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2698 consumesHandle(getTargetMutable(), effects);
2699 producesHandle(getOperation()->getOpResults(), effects);
2700 modifiesPayload(effects);
2701}
2702
2703LogicalResult transform::ReplaceOp::verify() {
2704 if (!getBodyRegion().hasOneBlock())
2705 return emitOpError() << "expected one block";
2706 if (std::distance(getBodyRegion().front().begin(),
2707 getBodyRegion().front().end()) != 1)
2708 return emitOpError() << "expected one operation in block";
2709 Operation *replacement = &getBodyRegion().front().front();
2710 if (replacement->getNumOperands() > 0)
2711 return replacement->emitOpError()
2712 << "expected replacement without operands";
2713 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2714 replacement->getNumRegions() > 0)
2715 return replacement->emitOpError()
2716 << "expect op that is isolated from above";
2717 return success();
2718}
2719
2720//===----------------------------------------------------------------------===//
2721// ScalarizeOp
2722//===----------------------------------------------------------------------===//
2723
/// Scalarizes the dynamic dimensions of `target` by tiling them with size 1;
/// static dimensions are left untiled (tile size 0).
DiagnosedSilenceableFailure
transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
  scf::SCFTilingOptions tilingOptions;
  // Tile sizes are computed lazily at tiling time from the op's loop bounds.
  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
    SmallVector<OpFoldResult> tileSizes;
    Location loc = target.getLoc();
    SmallVector<OpFoldResult> allShapeSizes =
        target.createFlatListOfOperandDims(b, loc);
    AffineMap map = target.getShapesToLoopsMap();
    // No shapes-to-loops map: return the empty list (no tiling).
    if (!map)
      return tileSizes;
    SmallVector<OpFoldResult> shapeSizes =
        allShapeSizes);
    // If the shape size is dynamic, tile by 1.
    // Otherwise, do not tile (i.e. tile size 0).
    for (OpFoldResult shapeSize : shapeSizes) {
      tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
                                                         : b.getIndexAttr(1));
    }
    return tileSizes;
  });
  rewriter.setInsertionPoint(target);
  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
      rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
  if (failed(maybeTilingResult))
    return emitDefaultDefiniteFailure(target);

  // Replace the payload with the tiled computation, or erase it when it has
  // no results to replace.
  if (target->getNumResults())
    rewriter.replaceOp(target, maybeTilingResult->replacements);
  else
    rewriter.eraseOp(target);

  // Report every tiled op produced by the tiling.
  results.reserve(maybeTilingResult->tiledOps.size());
  for (Operation *tiled : maybeTilingResult->tiledOps)
    results.push_back(tiled);
}
2765
2766//===----------------------------------------------------------------------===//
2767// ConvertToLoopsOp
2768//===----------------------------------------------------------------------===//
2769
/// Lowers each payload op implementing TilingInterface to a nest of scf.for
/// loops, erases the original op, and returns all generated loops through the
/// single result handle.
DiagnosedSilenceableFailure
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  SmallVector<Operation *> loops;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Lowering requires the payload to implement TilingInterface.
    auto tilingOp = dyn_cast<TilingInterface>(*target);
    if (!tilingOp) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected the payload to implement TilingInterface";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }
    rewriter.setInsertionPoint(target);
    FailureOr<SmallVector<scf::ForOp>> generatedLoops =
        scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
    if (failed(generatedLoops))
      return emitDefaultDefiniteFailure(target);
    // Accumulate loops across all targets, in order.
    for (scf::ForOp &loop : *generatedLoops) {
      loops.push_back(loop.getOperation());
    }
    // The payload op has been fully lowered to loops; remove it.
    rewriter.eraseOp(target);
  }
  results.set(cast<OpResult>(getResult()), loops);
}
2797
2798//===----------------------------------------------------------------------===//
2799// RewriteInDestinationPassingStyleOp
2800//===----------------------------------------------------------------------===//
2801
/// Rewrites supported tensor ops (from_elements, generate, pad) into
/// destination-passing style, reporting the rewritten op on success.
DiagnosedSilenceableFailure
transform::RewriteInDestinationPassingStyleOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  // Dispatch on the concrete tensor op kind; each case delegates to the
  // matching rewriteInDestinationPassingStyle overload.
  FailureOr<Operation *> maybeResult =
      .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
          [&rewriter](auto op) {
            return rewriteInDestinationPassingStyle(rewriter, op);
          });
  // Unsupported ops or failed rewrites are silenceable failures.
  if (failed(maybeResult))
    return emitDefaultSilenceableFailure(target);
  results.push_back(*maybeResult);
}
2819
2820//===----------------------------------------------------------------------===//
2821// SplitOp
2822//===----------------------------------------------------------------------===//
2823
2824DiagnosedSilenceableFailure
2825SplitOp::apply(transform::TransformRewriter &rewriter,
2826 TransformResults &results, TransformState &state) {
2827 // Collect the dynamic split points if provided.
2828 SmallVector<Operation *> payload =
2829 llvm::to_vector(state.getPayloadOps(getTarget()));
2830
2831 bool isMultiwaySplit = getMultiway();
2832
2833 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2834 return mlir::emitSilenceableFailure(getLoc())
2835 << "requires exactly one target when "
2836 "multiway split is enabled (got "
2837 << llvm::range_size(payload) << ")";
2838 }
2839
2840 SmallVector<OpFoldResult> chunkSizes;
2841
2842 if (!isMultiwaySplit)
2843 chunkSizes.reserve(payload.size());
2844
2845 if (getDynamicChunkSizes()) {
2847 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2848 chunkSizes = llvm::map_to_vector(
2849 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2850 if (op->getNumResults() != 1 ||
2851 !op->getResult(0).getType().isIndex()) {
2852 diag = emitSilenceableError()
2853 << "expected dynamic split point handle to point to a "
2854 "single-result index-typed op";
2855 diag.attachNote(op->getLoc()) << "dynamic split point";
2856 }
2857 return OpFoldResult(op->getResult(0));
2858 });
2859 } else {
2860 chunkSizes = llvm::map_to_vector(
2861 state.getParams(getDynamicChunkSizes()),
2862 [](Attribute attr) { return OpFoldResult(attr); });
2863 }
2864 if (diag.isSilenceableFailure())
2865 return diag;
2866
2867 // For multiway split, a single payload is expected to have multiple
2868 // split points.
2869 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2870 return emitDefiniteFailure()
2871 << "expected the dynamic split point handle to point to as "
2872 "many operations ("
2873 << chunkSizes.size() << ") as the target handle ("
2874 << payload.size() << ")";
2875 }
2876 } else {
2877 chunkSizes.resize(payload.size(),
2878 rewriter.getIndexAttr(getStaticChunkSizes()));
2879 }
2880
2881 auto checkStructuredOpAndDimensions =
2882 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2883 if (!linalgOp) {
2884 auto diag = emitSilenceableError() << "only applies to structured ops";
2885 diag.attachNote(loc) << "target op";
2886 return diag;
2887 }
2888
2889 if (getDimension() >= linalgOp.getNumLoops()) {
2890 auto diag = emitSilenceableError() << "dimension " << getDimension()
2891 << " does not exist in target op";
2892 diag.attachNote(loc) << "target op";
2893 return diag;
2894 }
2896 };
2897
2898 auto checkFailureInSplitting =
2899 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2900 if (hasFailed) {
2901 auto diag = emitDefiniteFailure() << "internal failure in splitting";
2902 diag.attachNote(loc) << "target op";
2903 return diag;
2904 }
2906 };
2907
2908 SmallVector<Operation *> opList;
2909 if (isMultiwaySplit) {
2910
2911 // Split a single target operation at multiple points.
2912 TilingInterface head, tail;
2913 Operation *target = payload.front();
2914
2915 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2916
2917 // Check that the target is a valid LinalgOp with correct dimensions.
2918 DiagnosedSilenceableFailure diag =
2919 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2920 if (diag.isSilenceableFailure())
2921 return diag;
2922
2923 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2924
2925 if (idx > 0)
2926 target = tail.getOperation();
2927
2928 if (!target)
2929 break;
2930
2931 linalgOp = cast<LinalgOp>(target);
2932 Location loc = target->getLoc();
2933
2934 rewriter.setInsertionPoint(linalgOp);
2935 std::tie(head, tail) = linalg::splitOp(
2936 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2937 getDimension(), chunkSize);
2938
2939 // Propagate errors.
2940 DiagnosedSilenceableFailure diag =
2941 checkFailureInSplitting(!head && !tail, loc);
2942 if (diag.isDefiniteFailure())
2943 return diag;
2944
2945 opList.push_back(head.getOperation());
2946 }
2947
2948 // Append any leftover parts to the end of the result list.
2949 if (tail)
2950 opList.push_back(tail.getOperation());
2951
2952 } else {
2953 // Split each target operation.
2954 SmallVector<Operation *> first, second;
2955 Operation *noSecondPart = nullptr;
2956 for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2957 Operation *target = std::get<0>(pair);
2958 Location loc = target->getLoc();
2959 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2960 DiagnosedSilenceableFailure diag =
2961 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2962
2963 if (diag.isSilenceableFailure())
2964 return diag;
2965
2966 rewriter.setInsertionPoint(linalgOp);
2967 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2968 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2969 getDimension(), std::get<1>(pair));
2970
2971 // Propagate errors.
2972 DiagnosedSilenceableFailure diagSplit =
2973 checkFailureInSplitting(!first.back() && !second.back(), loc);
2974 if (diagSplit.isDefiniteFailure())
2975 return diag;
2976
2977 // Do not add null second parts.
2978 if (!second.back()) {
2979 noSecondPart = target;
2980 second.pop_back();
2981 }
2982 }
2983
2984 if (second.size() != first.size() && !second.empty()) {
2985 auto diag = emitSilenceableError()
2986 << "splitting does not produce the second part for a subset "
2987 "of targets";
2988 diag.attachNote()
2989 << "expected splitting to produce the second part of all "
2990 "or none of the targets";
2991 diag.attachNote(noSecondPart->getLoc())
2992 << "first target with no second part";
2993 return diag;
2994 }
2995
2996 opList.append(first);
2997 if (!second.empty())
2998 opList.append(second);
2999 }
3000 results.set(cast<OpResult>(getSplitList()), opList);
3002}
3003
3004void SplitOp::getEffects(
3005 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3006 consumesHandle(getTargetMutable(), effects);
3007 if (getDynamicChunkSizes())
3008 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
3009 producesHandle(getOperation()->getOpResults(), effects);
3010 modifiesPayload(effects);
3011}
3012
3013ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
3014 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
3015 IntegerAttr staticChunkSizes;
3016 if (parser.parseOperand(target) || parser.parseKeyword("after"))
3017 return failure();
3018
3019 OptionalParseResult dynamicPointParseResult =
3020 parser.parseOptionalOperand(dynamicChunkSizes);
3021 if (!dynamicPointParseResult.has_value()) {
3022 int64_t staticChunkSizesValue;
3023 if (failed(parser.parseInteger(staticChunkSizesValue)))
3024 return failure();
3025
3026 staticChunkSizes =
3027 parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
3028 }
3029
3030 Type targetType;
3031 if (parser.parseOptionalAttrDict(result.attributes) ||
3032 parser.parseColonType(targetType) ||
3033 parser.resolveOperand(target, targetType, result.operands)) {
3034 return failure();
3035 }
3036 if (dynamicPointParseResult.has_value()) {
3037 Type chunkSizesType;
3038 if (failed(*dynamicPointParseResult) || parser.parseComma() ||
3039 parser.parseType(chunkSizesType) ||
3040 parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
3041 result.operands)) {
3042 return failure();
3043 }
3044
3045 staticChunkSizes =
3046 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
3047 }
3048
3049 result.addAttribute(
3050 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
3051 staticChunkSizes);
3052 result.addTypes(targetType);
3053 return success();
3054}
3055
3056void SplitOp::print(OpAsmPrinter &printer) {
3057 printer << " " << getTarget() << " after ";
3058 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
3059 if (staticChunkSize != ShapedType::kDynamic)
3060 printer << staticChunkSize;
3061 else
3062 printer << getDynamicChunkSizes();
3063 printer << " ";
3064 printer.printOptionalAttrDict(getOperation()->getAttrs(),
3065 {getStaticChunkSizesAttrName()});
3066 printer << " : " << getTarget().getType();
3067 if (staticChunkSize == ShapedType::kDynamic)
3068 printer << ", " << getDynamicChunkSizes().getType();
3069}
3070
3071LogicalResult SplitOp::verify() {
3072 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3073 (getDynamicChunkSizes() == nullptr)) {
3074 return emitOpError() << "expects either a dynamic or a static split "
3075 "point to be provided";
3076 }
3077 return success();
3078}
3079
3080//===----------------------------------------------------------------------===//
3081// SplitReductionOp
3082//===----------------------------------------------------------------------===//
3083
3084void transform::SplitReductionOp::build(
3085 OpBuilder &builder, OperationState &result, Value target,
3086 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
3087 bool useScalingAlgorithm, bool useAlloc) {
3088 MLIRContext *ctx = builder.getContext();
3089 result.addOperands(target);
3090 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
3091 builder.getI64IntegerAttr(splitFactor));
3092 result.addAttribute(
3093 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
3094 builder.getI64IntegerAttr(insertSplitDimension));
3095 if (innerParallel) {
3096 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
3097 builder.getUnitAttr());
3098 }
3099 if (useScalingAlgorithm) {
3100 result.addAttribute(
3101 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
3102 builder.getUnitAttr());
3103 }
3104 if (useAlloc) {
3105 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
3106 builder.getUnitAttr());
3107 }
3108 auto resultType = transform::AnyOpType::get(ctx);
3109 result.addTypes({resultType, resultType, resultType, resultType});
3110}
3111
/// Splits the reduction dimension of `target` per the op's attributes, using
/// either the scaling-based or the standard split-reduction algorithm.
DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // The control function returns the same statically-configured options for
  // every op.
  ControlSplitReductionFn splitFn = [&](LinalgOp) {
    return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
                                         unsigned(getInsertSplitDimension()),
                                         bool(getInnerParallel())};
  };
  rewriter.setInsertionPoint(target);
  FailureOr<SplitReductionResult> splitResult =
      (getUseScalingAlgorithm())
          ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
          : splitReduction(rewriter, target, splitFn, getUseAlloc());
  if (failed(splitResult))
    return emitDefaultDefiniteFailure(target);

  // Result order: init/alloc, fill, split linalg op, combining linalg op.
  results.push_back(splitResult->initOrAlloc);
  results.push_back(splitResult->fillOp);
  results.push_back(splitResult->splitLinalgOp);
  results.push_back(splitResult->resultCombiningLinalgOp);
}
3135
3136//===----------------------------------------------------------------------===//
3137// TileReductionUsingForOp
3138//===----------------------------------------------------------------------===//
3139
3140void transform::TileReductionUsingForOp::build(
3141 OpBuilder &builder, OperationState &result, Value target,
3142 ArrayRef<int64_t> staticTileSizes) {
3143 // Call the default builder.
3144 // This is future-proof re mixed static-dynamic and setting up the proper
3145 // operands segment sizes attributes for multiple variadic operands.
3146 // In the absence of this, horrible bugs ensue.
3147 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3148 MLIRContext *ctx = builder.getContext();
3149 auto opTy = transform::AnyOpType::get(ctx);
3150 auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
3151 build(builder, result,
3152 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3153 /*target=*/target,
3154 /*reduction_dims=*/nullptr,
3155 /*tile_sizes=*/staticTileSizesAttr);
3156}
3157
/// Tiles the reduction dimensions of `target` into partial reductions using
/// scf.for loops, followed by a merge of the partial results.
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);

  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
  if (!partialReductionOp) {
        target->getLoc(),
        "Operation should implement PartialReductionOpInterface");
  }

  // Use the explicitly provided reduction dimensions, or default to every
  // reduction iterator of the op.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [idx, iteratorType] :
         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(idx);
    }
  }

  // Configure partial-reduction tiling with scf.for loops.
  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
  options.setReductionTilingStrategy(
  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
  options.setReductionDims(reductionDims);
  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCF(rewriter, partialReductionOp, options);

  if (failed(result)) {
    return emitSilenceableFailure(getLoc(),
                                  "failed to tile using partial reduction");
  }
  rewriter.replaceOp(target, result->replacements);
  // Result order: init values, partially tiled ops, merge ops, outer loop.
  for (Value initValue : result->initialValues)
    results.push_back(initValue.getDefiningOp());
  for (auto *parallelTiledOp : result->tiledOps)
    results.push_back(parallelTiledOp);
  for (auto *mergeOp : result->mergeOps)
    results.push_back(mergeOp);
  results.push_back(result->loops.front());
}
3204
3205//===----------------------------------------------------------------------===//
3206// TileReductionUsingForallOp
3207//===----------------------------------------------------------------------===//
3208
3209void transform::TileReductionUsingForallOp::build(
3210 OpBuilder &builder, OperationState &result, Value target,
3211 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3212 ArrayAttr mapping) {
3213 // Call the default builder.
3214 // This is future-proof re mixed static-dynamic and setting up the proper
3215 // operands segment sizes attributes for multiple variadic operands.
3216 // In the absence of this, horrible bugs ensue.
3217 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3218 MLIRContext *ctx = builder.getContext();
3219 auto opTy = transform::AnyOpType::get(ctx);
3220 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3221 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3222 build(builder, result,
3223 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3224 /*target=*/target,
3225 /*reduction_dims=*/{},
3226 /*num_threads=*/staticNumThreadsAttr,
3227 /*tile_sizes=*/staticTileSizesAttr,
3228 /*mapping=*/mapping);
3229}
3230
/// Tiles the reduction dimensions of `target` into partial reductions using
/// an scf.forall loop (optionally mapped), followed by a merge step.
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);

  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
  if (!partialReductionOp) {
        target->getLoc(),
        "Operation should implement PartialReductionOpInterface");
  }
  SmallVector<OpFoldResult> numThreads =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
  SmallVector<OpFoldResult> tileSizes =

  // Configure partial-reduction tiling with an scf.forall loop.
  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  options.setReductionTilingStrategy(
  // An explicit thread count takes precedence over tile sizes.
  if (!getNumThreads().empty()) {
    options.setNumThreads(numThreads);
  } else {
    options.setTileSizes(tileSizes);
  }
  if (auto mapping = getMapping()) {
    options.setMapping(mapping.value().getValue());
  }
  // Use the explicitly provided reduction dimensions, or default to every
  // reduction iterator of the op.
  SmallVector<unsigned> reductionDims =
      extractFromIntegerArrayAttr<unsigned>(getReductionDims());
  if (reductionDims.empty()) {
    for (auto [idx, iteratorType] :
         llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(idx);
    }
  }
  options.setReductionDims(reductionDims);
  FailureOr<scf::SCFTilingResult> result =
      scf::tileUsingSCF(rewriter, partialReductionOp, options);

  if (failed(result)) {
    auto diag = emitSilenceableError() << "could not tile reduction";
    return diag;
  }
  rewriter.replaceOp(target, result->replacements);

  // Result order: init values, partially tiled ops, merge ops, loop.
  for (Value initValue : result->initialValues)
    results.push_back(initValue.getDefiningOp());
  for (auto *parallelTiledOp : result->tiledOps)
    results.push_back(parallelTiledOp);
  for (auto *mergeOp : result->mergeOps)
    results.push_back(mergeOp);
  results.push_back(result->loops.front());
}
3288
3289//===----------------------------------------------------------------------===//
3290// ContinuousTileSizesOp
3291//===----------------------------------------------------------------------===//
3292
/// Computes continuous (multi-size) tile sizes for the single payload op.
/// When the results are params, the sizes are computed statically; otherwise
/// size-computation ops are materialized in the IR before the payload.
DiagnosedSilenceableFailure
transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
                                        TransformResults &transformResults,
                                        TransformState &state) {

  SmallVector<Operation *> targetOps =
      llvm::to_vector(state.getPayloadOps(getTarget()));

  // Exactly one payload op is supported.
  if (!llvm::hasSingleElement(targetOps)) {
    return mlir::emitSilenceableFailure(getLoc())
           << "requires exactly one target (got " << llvm::range_size(targetOps)
           << ")";
  }

  Operation *target = *targetOps.begin();
  auto linalgOp = dyn_cast<LinalgOp>(target);
  auto tileableOp = dyn_cast<TilingInterface>(target);

  if (!linalgOp)
    return emitDefiniteFailure() << "expected Linalg Op";

  OpBuilder builder(linalgOp.getContext());

  // Param-typed results: compute the sizes statically, without touching IR.
  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
    // Static computation requires fully static shapes.
    if (linalgOp.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(linalgOp->getLoc()) << "payload op";
      return diag;
    }

    FailureOr<StaticContinuousTileSizeSpecification> spec =
        computeStaticContinuousTileSizes(linalgOp, getDimension(),
                                         getTargetSize());
    if (failed(spec)) {
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";
    }

    // A chunk is the extent covered by one tile size: tileSize * tripCount.
    SmallVector<int64_t> chunkSizes;

    for (auto &&[tileSize, tripCount] :
         llvm::zip_equal(spec->tileSizes, spec->tripCounts))
      chunkSizes.push_back(tileSize * tripCount);

    auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
      return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
        return builder.getI64IntegerAttr(value);
      });
    };
    transformResults.setParams(cast<OpResult>(getTileSizes()),
                               getI64AttrsFromI64(spec->tileSizes));
    transformResults.setParams(cast<OpResult>(getChunkSizes()),
                               getI64AttrsFromI64(chunkSizes));

  }

  // Handle-typed results: materialize size computations before the payload.
  builder.setInsertionPoint(linalgOp);

  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
  unsigned dimension = getDimension();

  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
      builder, tileableOp, dimension, targetSize, true);
  if (failed(spec)) {
    return emitSilenceableError() << "could not generate tile size computation";
  }

  // Each chunk size is tileSize * tripCount, built as an affine.apply s0*s1.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  AffineExpr s1 = builder.getAffineSymbolExpr(1);
  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
    return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
                                           ofrs);
  };

  SmallVector<Value> chunkSizes;
  Value splitPoint;
  for (auto &&[tileSize, tripCount] :
       llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
    splitPoint = apply(s0 * s1, {tileSize, tripCount});
    chunkSizes.push_back(splitPoint);
  }

  // Handles are associated with the ops defining the computed values.
  auto getDefiningOps = [&](ArrayRef<Value> values) {
    return llvm::map_to_vector(values, [&](Value value) -> Operation * {
      return value.getDefiningOp();
    });
  };

  transformResults.set(cast<OpResult>(getTileSizes()),
                       getDefiningOps(spec->tileSizes));
  transformResults.set(cast<OpResult>(getChunkSizes()),
                       getDefiningOps(chunkSizes));

}
3391
3392LogicalResult transform::ContinuousTileSizesOp::verify() {
3393
3394 if (getTileSizes().getType() != getChunkSizes().getType()) {
3395 return emitOpError() << "expects all results type to be the same";
3396 }
3397
3398 return success();
3399}
3400
3401void transform::ContinuousTileSizesOp::getEffects(
3402 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3403 if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3404 onlyReadsPayload(effects);
3405 else
3406 modifiesPayload(effects);
3407 onlyReadsHandle(getTargetMutable(), effects);
3408 producesHandle(getOperation()->getOpResults(), effects);
3409}
3410
    Type targetType, Type tileSizes,
    Type) {
  // Only one result type is printed: the chunk-sizes type is elided because
  // the parser constrains it to equal the tile-sizes type.
  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
}
3416
    Type &targetType,
    Type &tileSizesType,
    Type &chunkSizesType) {
  FunctionType funcType;
  llvm::SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseType<FunctionType>(funcType)))
    return failure();

  // The trailing type must be a functional type with exactly one input and
  // one result.
  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
  }
  // Both result types share the single printed result type.
  targetType = funcType.getInput(0);
  tileSizesType = chunkSizesType = funcType.getResult(0);

  return success();
}
3435
3436//===----------------------------------------------------------------------===//
3437// TileUsingForOp
3438//===----------------------------------------------------------------------===//
3439
3440void transform::TileUsingForOp::build(
3441 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3442 Value target, ArrayRef<int64_t> staticTileSizes,
3443 ArrayRef<int64_t> interchange,
3444 std::optional<ArrayRef<bool>> scalableSizes) {
3445 return build(builder, result, loopTypes,
3446 /*target=*/target,
3447 /*mixedTileSizes=*/
3448 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3449 interchange, scalableSizes);
3450}
3451
3452void transform::TileUsingForOp::build(
3453 OpBuilder &builder, OperationState &result, Value target,
3454 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3455 std::optional<ArrayRef<bool>> scalableSizes) {
3456 build(builder, result, target,
3457 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3458 interchange, scalableSizes);
3459}
3460
3461void transform::TileUsingForOp::build(
3462 OpBuilder &builder, OperationState &result, Value target,
3463 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3464 std::optional<ArrayRef<bool>> scalableSizes) {
3465 // Loop types are automaticaly splat by the callee, setting up one is
3466 // enough.
3467 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3468 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3469 scalableSizes);
3470}
3471
3472void transform::TileUsingForOp::build(
3473 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3474 Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3475 ArrayRef<int64_t> interchange,
3476 std::optional<ArrayRef<bool>> scalableSizes) {
3477 SmallVector<int64_t> staticTileSizes;
3478 SmallVector<Value> dynamicTileSizes;
3479 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3480 // Call the default builder which sets up the proper operands segment sizes
3481 // attributes for multiple variadic operands. In the absence of this,
3482 // horrible bugs ensue.
3483 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3484 unsigned numExpectedLoops =
3485 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3486 SmallVector<Type> resultTypes;
3487 resultTypes.reserve(numExpectedLoops);
3488 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3489 "expected one loop type or as many as loops");
3490 if (loopTypes.size() == 1)
3491 resultTypes.append(numExpectedLoops, loopTypes[0]);
3492 else
3493 llvm::append_range(resultTypes, loopTypes);
3494 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3495 if (scalableSizes.has_value())
3496 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3497 build(builder, result, /*tiled_linalg_op=*/target.getType(),
3498 /*loops=*/resultTypes,
3499 /*target=*/target,
3500 /*dynamic_sizes=*/dynamicTileSizes,
3501 /*static_sizes=*/staticTileSizesAttr,
3502 /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3503 /*scalable_sizes=*/expandedScalableSizes);
3504}
3505
3506LogicalResult transform::TileUsingForOp::verify() {
3507 if (getMixedSizes().size() != getScalableSizes().size())
3508 return emitOpError("expected same number of sizes (")
3509 << getMixedSizes().size() << ") and scalable sizes ("
3510 << getScalableSizes().size() << ")";
3511 ArrayRef<int64_t> staticSizes = getStaticSizes();
3512 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3513 if (getLoops().size() != numExpectedLoops)
3514 return emitOpError("expected number of loops to tile (")
3515 << numExpectedLoops << ") to match number of `loops` results ("
3516 << getLoops().size() << ")";
3517 return success();
3518}
3519
3520DiagnosedSilenceableFailure
3521transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3522 TransformResults &transformResults,
3523 TransformState &state) {
3524 ArrayRef<int64_t> tileSizes = getStaticSizes();
3525
3526 SmallVector<Operation *> targets =
3527 llvm::to_vector(state.getPayloadOps(getTarget()));
3528 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3529 SmallVector<SmallVector<int64_t>> paramSizes;
3530 dynamicSizeProducers.reserve(getDynamicSizes().size());
3531 paramSizes.reserve(getDynamicSizes().size());
3532 for (Value transformValue : getDynamicSizes()) {
3533 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
3534 dynamicSizeProducers.push_back({});
3535 ArrayRef<Attribute> params = state.getParams(transformValue);
3536 paramSizes.push_back(llvm::map_to_vector(params, [](Attribute attr) {
3537 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3538 }));
3539
3540 if (paramSizes.back().size() != targets.size()) {
3541 DiagnosedSilenceableFailure diag =
3542 emitSilenceableError()
3543 << "expected as many parameter values ("
3544 << dynamicSizeProducers.back().size() << ") as target ops ("
3545 << targets.size() << ")";
3546 diag.attachNote(transformValue.getLoc()) << "for this parameter";
3547 return diag;
3548 }
3549
3550 continue;
3551 }
3552 paramSizes.push_back({});
3553 dynamicSizeProducers.push_back(
3554 llvm::to_vector(state.getPayloadOps(transformValue)));
3555
3556 if (dynamicSizeProducers.back().size() != targets.size()) {
3557 DiagnosedSilenceableFailure diag =
3558 emitSilenceableError()
3559 << "expected as many dynamic size-producing operations ("
3560 << dynamicSizeProducers.back().size() << ") as target ops ("
3561 << targets.size() << ")";
3562 diag.attachNote(transformValue.getLoc()) << "for this handle";
3563 return diag;
3564 }
3565
3566 for (Operation *op : dynamicSizeProducers.back()) {
3567 if (op->getNumResults() == 1 &&
3568 isa<IndexType>(op->getResult(0).getType())) {
3569 continue;
3570 }
3571
3572 DiagnosedSilenceableFailure diag =
3573 emitSilenceableError() << "expected sizes to be produced by ops "
3574 "with a single index-type result";
3575 diag.attachNote(op->getLoc()) << "size producer op";
3576 diag.attachNote(transformValue.getLoc()) << "for this handle";
3577 return diag;
3578 }
3579 }
3580
3581 SmallVector<Operation *> tiled;
3582 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3583 loops.resize(getLoops().size());
3584 auto scalableSizes = getScalableSizes();
3585 for (auto [i, op] : llvm::enumerate(targets)) {
3586 auto tilingInterface = dyn_cast<TilingInterface>(op);
3587 if (!tilingInterface) {
3588 DiagnosedSilenceableFailure diag =
3589 emitSilenceableError()
3590 << "only ops implementing TilingInterface are supported";
3591 diag.attachNote(op->getLoc()) << "target op";
3592 return diag;
3593 }
3594 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3595 DiagnosedSilenceableFailure diag =
3596 emitSilenceableError()
3597 << "too many tiles provided, expected at most "
3598 << tilingInterface.getLoopIteratorTypes().size() << " found "
3599 << tileSizes.size();
3600 diag.attachNote(op->getLoc()) << "target op";
3601 return diag;
3602 }
3603
3604 scf::SCFTilingOptions tilingOptions;
3605 if (tileSizes.empty()) {
3606 tilingOptions.setTileSizeComputationFunction(
3607 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3608 return {};
3609 });
3610 } else {
3611 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3612 Operation *) {
3613 SmallVector<OpFoldResult> sizes;
3614 sizes.reserve(tileSizes.size());
3615 unsigned dynamicIdx = 0;
3616
3617 for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3618 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3619 if (scalableSizes[ofrIdx]) {
3621 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3622 Value vscale =
3623 vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
3624 sizes.push_back(
3625 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3626 } else {
3627 sizes.push_back(attr);
3628 }
3629 continue;
3630 }
3631 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3632 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3633 ++dynamicIdx;
3634 assert((dynamicSizes.empty() ^ params.empty()) &&
3635 "expected either dynamic sizes or parameters");
3636 if (!params.empty()) {
3637 sizes.push_back(b.getIndexAttr(params[index]));
3638 } else {
3639 sizes.push_back(dynamicSizes[index]->getResult(0));
3640 }
3641 }
3642 return sizes;
3643 });
3644 }
3645
3646 tilingOptions.setInterchange(getInterchange());
3647 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3648 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3649 if (failed(maybeTilingResult))
3651
3652 rewriter.replaceOp(op, maybeTilingResult->replacements);
3653
3654 tiled.append(maybeTilingResult->tiledOps);
3655 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3656 loops[en2.index()].push_back(en2.value());
3657 }
3658
3659 transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3660 for (const auto &en : llvm::enumerate(loops))
3661 transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3662
3664}
3665
3666SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3667 ValueRange dynamic = getDynamicSizes();
3668 ArrayRef<int64_t> tileSizes = getStaticSizes();
3669 SmallVector<OpFoldResult> results;
3670 results.reserve(tileSizes.size());
3671 unsigned dynamicPos = 0;
3672 Builder builder(getContext());
3673 for (int64_t size : tileSizes) {
3674 if (size == ShapedType::kDynamic) {
3675 results.push_back(dynamic[dynamicPos++]);
3676 } else {
3677 results.push_back(builder.getIndexAttr(size));
3678 }
3679 }
3680 return results;
3681}
3682
void transform::TileUsingForOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed: tiling replaces the targeted payload ops.
  consumesHandle(getTargetMutable(), effects);
  // Dynamic-size handles are only inspected, never invalidated.
  onlyReadsHandle(getDynamicSizesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
3690
3691//===----------------------------------------------------------------------===//
3692// TileUsingForallOp
3693//===----------------------------------------------------------------------===//
3694
3695void transform::TileUsingForallOp::build(OpBuilder &builder,
3696 OperationState &result, Value target,
3697 ArrayRef<int64_t> staticTileSizes,
3698 transform::TileSizesSpec,
3699 ArrayAttr mapping) {
3700 return build(builder, result,
3701 /*target=*/target,
3702 /*mixedTileSizes=*/
3703 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3704 /*_=*/TileSizesSpec(),
3705 /*mapping=*/mapping);
3706}
3707
3708void transform::TileUsingForallOp::build(OpBuilder &builder,
3709 OperationState &result, Value target,
3710 ArrayRef<OpFoldResult> mixedTileSizes,
3711 transform::TileSizesSpec,
3712 ArrayAttr mapping) {
3713 SmallVector<int64_t> staticTileSizes;
3714 SmallVector<Value> dynamicTileSizes;
3715 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3716 // Call the default builder which sets up the proper operands segment sizes
3717 // attributes for multiple variadic operands. In the absence of this,
3718 // horrible bugs ensue.
3719 MLIRContext *ctx = builder.getContext();
3720 auto operationType = transform::AnyOpType::get(ctx);
3721 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3722 build(builder, result,
3723 /*resultTypes=*/TypeRange{operationType, operationType},
3724 /*target=*/target,
3725 /*num_threads=*/ValueRange{},
3726 /*tile_sizes=*/dynamicTileSizes,
3727 /*packed_num_threads=*/Value(),
3728 /*packed_tile_sizes=*/Value(),
3729 /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3730 /*static_tile_sizes=*/staticTileSizesAttr,
3731 /*mapping=*/mapping);
3732}
3733
3734void transform::TileUsingForallOp::build(OpBuilder &builder,
3735 OperationState &result, Value target,
3736 ArrayRef<int64_t> staticNumThreads,
3737 transform::NumThreadsSpec,
3738 ArrayAttr mapping) {
3739 return build(builder, result, target,
3740 getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3741 NumThreadsSpec(), mapping);
3742}
3743
3744void transform::TileUsingForallOp::build(OpBuilder &builder,
3745 OperationState &result, Value target,
3746 ArrayRef<OpFoldResult> mixedNumThreads,
3747 transform::NumThreadsSpec,
3748 ArrayAttr mapping) {
3749 SmallVector<int64_t> staticNumThreads;
3750 SmallVector<Value> dynamicNumThreads;
3751 dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3752 staticNumThreads);
3753 // Call the default builder which sets up the proper operands segment sizes
3754 // attributes for multiple variadic operands. In the absence of this,
3755 // horrible bugs ensue.
3756 MLIRContext *ctx = builder.getContext();
3757 auto operationType = transform::AnyOpType::get(ctx);
3758 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3759 build(builder, result,
3760 /*resultTypes=*/TypeRange{operationType, operationType},
3761 /*target=*/target,
3762 /*num_threads=*/dynamicNumThreads,
3763 /*tile_sizes=*/ValueRange{},
3764 /*packed_num_threads=*/Value(),
3765 /*packed_tile_sizes=*/Value(),
3766 /*static_num_threads=*/staticNumThreadsAttr,
3767 /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3768 /*mapping=*/mapping);
3769}
3770
3771/// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3772/// normalized upper bound.
3773static SmallVector<OpFoldResult>
3776 ArrayRef<OpFoldResult> steps) {
3777 AffineExpr s0, s1, s2;
3778 bindSymbols(rewriter.getContext(), s0, s1, s2);
3779 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3780 SmallVector<OpFoldResult> normalizedUbs;
3781 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3783 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3784 normalizedUbs.push_back(normalizedUb);
3785 }
3786 return normalizedUbs;
3787}
3788
3789/// When a loop is normalized, the uses of the induction variable within the
3790/// loop need to replaced with `original_lb + old_iv * original_step`.
3792 Location loc, ValueRange ivs,
3794 ArrayRef<OpFoldResult> steps) {
3795 AffineExpr s0, s1;
3796 AffineExpr d0;
3797 bindSymbols(rewriter.getContext(), s0, s1);
3798 bindDims(rewriter.getContext(), d0);
3799 AffineExpr denormExpr = s0 + d0 * s1;
3800 SmallVector<Value> denormalizedIvs;
3801
3802 for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3804 rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3805 denormalizedIvs.push_back(
3806 getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3807 }
3808 return denormalizedIvs;
3809}
3810
/// Given a `scf.forall` loop return a loop op with the loop bounds
/// normalized, i.e. lower bounds of 0 and steps of 1.
/// TODO: Replace this with a general utility to normalize `scf.forall`.
/// At the time of writing, this wasn't done since adding this to `scf`
/// dialect would disallow using of `affine.apply` operations due
/// to cyclic dependencies. To avoid churn in lit tests
/// with the change this was added with, defer that to a follow up.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
                                           scf::ForallOp loop) {
  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
  SmallVector<OpFoldResult> steps = loop.getMixedStep();

  // Already normalized: nothing to do.
  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
    return loop;
  }

  Location loc = loop.getLoc();
  SmallVector<OpFoldResult> normalizedUbs =
      normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
                                          rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
                                            rewriter.getIndexAttr(1));

  // Create an empty normalized loop carrying over outputs and mapping.
  auto normalizedForallOp = scf::ForallOp::create(
      rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
      loop.getOutputs(), loop.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
  OpBuilder::InsertionGuard g(rewriter);
  Block *normalizedLoopBlock = normalizedForallOp.getBody();
  rewriter.setInsertionPointToStart(normalizedLoopBlock);

  // Move the original body over, substituting each old induction variable
  // with its denormalized equivalent (lb + iv * step).
  SmallVector<Value> argValues =
      denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
                   normalizedForallOp.getRegionIterArgs().end());
  Block *origLoopBlock = loop.getBody();
  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);

  rewriter.replaceOp(loop, normalizedForallOp);
  return normalizedForallOp;
}
3856
3858 RewriterBase &rewriter, transform::TransformState &state,
3859 TransformOpInterface transformOp, Operation *target,
3860 ArrayRef<OpFoldResult> mixedNumThreads,
3861 ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3862 scf::SCFTilingResult &tilingResult) {
3863 // Transform all targets one by one.
3864 auto tileableOp = dyn_cast<TilingInterface>(target);
3865 if (!tileableOp) {
3867 transformOp.emitSilenceableError()
3868 << "only TilingInterface ops are supported";
3869 diag.attachNote(target->getLoc()) << "target op";
3870 return diag;
3871 }
3872 rewriter.setInsertionPoint(tileableOp);
3873 scf::SCFTilingOptions options;
3874 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3875 if (!mixedNumThreads.empty()) {
3876 options.setNumThreads(mixedNumThreads);
3877 } else {
3878 options.setTileSizes(mixedTileSizes);
3879 }
3880 if (mapping) {
3881 options.setMapping(mapping.value().getValue());
3882 }
3883 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3884 scf::tileUsingSCF(rewriter, tileableOp, options);
3885
3886 if (failed(maybeTilingResult))
3887 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3888
3889 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3890
3891 tilingResult = *maybeTilingResult;
3892
3893 if (mixedNumThreads.empty()) {
3894 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3895 OpBuilder::InsertionGuard g(rewriter);
3896 rewriter.setInsertionPoint(generatedForallOp);
3897 scf::ForallOp normalizedForallOp =
3898 normalizeForallLoopOp(rewriter, generatedForallOp);
3899 tilingResult.loops.front() = normalizedForallOp;
3900 }
3901
3903}
3904
3905DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3907 transform::TransformResults &transformResults,
3909 auto transformOp = cast<TransformOpInterface>(getOperation());
3910
3911 // Result payload ops.
3913 SmallVector<Operation *> tiledOps;
3914
3915 // Unpack handles.
3916 SmallVector<OpFoldResult> mixedNumThreads;
3918 getPackedNumThreads()
3920 state, transformOp, mixedNumThreads, getPackedNumThreads())
3922 state, transformOp, mixedNumThreads, getMixedNumThreads());
3923 if (!status.succeeded())
3924 return status;
3925 SmallVector<OpFoldResult> mixedTileSizes;
3926 status = getPackedTileSizes()
3928 state, transformOp, mixedTileSizes, getPackedTileSizes())
3930 state, transformOp, mixedTileSizes, getMixedTileSizes());
3931 if (!status.succeeded())
3932 return status;
3933
3934 for (Operation *target : state.getPayloadOps(getTarget())) {
3935 scf::SCFTilingResult tilingResult;
3937 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3938 getMapping(), tilingResult);
3939 if (!diag.succeeded())
3940 return diag;
3941 tileOps.push_back(tilingResult.loops.front());
3942 tiledOps.append(tilingResult.tiledOps);
3943 }
3944
3945 transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3946 transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3947
3949}
3950
void transform::TileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target is consumed (its payload ops are replaced); every size/thread
  // handle is merely inspected.
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getTileSizesMutable(), effects);
  onlyReadsHandle(getNumThreadsMutable(), effects);
  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
  onlyReadsHandle(getPackedTileSizesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
3961
3962SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3963 Builder b(getContext());
3964 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3965}
3966
3967SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3968 Builder b(getContext());
3969 return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3970}
3971
3972LogicalResult TileUsingForallOp::verify() {
3973 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3974 static_cast<int>(getPackedNumThreads() != Value());
3975 if (numThreadsSpec > 1)
3976 return emitOpError(
3977 "num_threads and packed_num_threads are mutually exclusive");
3978 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3979 static_cast<int>(getPackedTileSizes() != Value());
3980 if (tileSizesSpec > 1)
3981 return emitOpError(
3982 "tile_sizes and packed_tile_sizes are mutually exclusive");
3983 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3984 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3985 "must be specified");
3986 return success();
3987}
3988
3989//===----------------------------------------------------------------------===//
3990// VectorizeChildrenAndApplyPatternsOp
3991//===----------------------------------------------------------------------===//
3992
3993void transform::VectorizeChildrenAndApplyPatternsOp::build(
3994 OpBuilder &builder, OperationState &result, Value target,
3995 bool foldTypeExtensionsIntoContract, bool vectorizePadding,
3996 bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3997 result.addOperands(target);
3998 if (foldTypeExtensionsIntoContract) {
3999 result.addAttribute(
4000 VectorizeChildrenAndApplyPatternsOp::
4001 getFoldTypeExtensionsIntoContractAttrName(result.name),
4002 builder.getUnitAttr());
4003 }
4004 if (vectorizePadding) {
4005 result.addAttribute(
4006 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4007 result.name),
4008 builder.getUnitAttr());
4009 }
4010 if (vectorizeExtract) {
4011 result.addAttribute(
4012 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4013 result.name),
4014 builder.getUnitAttr());
4015 }
4016 if (flatten1DDepthwiseConv) {
4017 result.addAttribute(
4018 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4019 result.name),
4020 builder.getUnitAttr());
4021 }
4022 result.addTypes(transform::AnyOpType::get(builder.getContext()));
4023}
4024
4025namespace {
4026/// This is an helper only to call vectorize via a pattern inside of
4027/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
4028struct VectorizationPattern : public RewritePattern {
4029 explicit VectorizationPattern(MLIRContext *context,
4030 bool vectorizeExtract = false,
4031 bool flattenConv = false)
4032 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
4033 vectorizeNDExtract(vectorizeExtract),
4034 flatten1DDepthwiseConv(flattenConv) {}
4035 LogicalResult matchAndRewrite(Operation *op,
4036 PatternRewriter &rewriter) const override {
4038 return rewriter.notifyMatchFailure(op,
4039 "Unsupported Op, cannot vectorize");
4040 FailureOr<VectorizationResult> vectorResults =
4041 vectorize(rewriter, op, /*inputVectorSizes=*/{},
4042 /*inputScalableVecDims=*/{}, vectorizeNDExtract,
4043 flatten1DDepthwiseConv);
4044 if (failed(vectorResults))
4045 return failure();
4046 rewriter.replaceOp(op, vectorResults->replacements);
4047 return success();
4048 }
4049
4050private:
4051 /// Controls whether to vectorize `tensor.extract` when the input tensor is
4052 /// rank >= 2.
4053 bool vectorizeNDExtract = false;
4054 /// Controls whether to "flatten" the channel dimension when vectorising 1D
4055 /// depthwise convolutions. This should lead to bette vectorization for
4056 /// tensors with a low number of channel dimensions.
4057 bool flatten1DDepthwiseConv = false;
4058};
4059} // namespace
4060
4061DiagnosedSilenceableFailure
4062transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4063 transform::TransformRewriter &rewriter, Operation *target,
4064 transform::ApplyToEachResultList &results,
4065 transform::TransformState &state) {
4066 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4067 auto diag = this->emitOpError("requires isolated-from-above targets");
4068 diag.attachNote(target->getLoc()) << "non-isolated target";
4070 }
4071
4072 MLIRContext *ctx = getContext();
4073 RewritePatternSet patterns(ctx);
4074 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4075 getFlatten_1dDepthwiseConv());
4076
4077 if (!getDisableTransferPermutationMapLoweringPatterns())
4079
4080 if (!getDisableMultiReductionToContractPatterns())
4082
4084
4085 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
4086 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4087 /*benefit=*/2);
4088 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4089 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4091
4092 patterns.add<CopyVectorizationPattern>(ctx);
4093
4094 if (getFoldTypeExtensionsIntoContract())
4096
4097 if (getVectorizePadding()) {
4099 // This creates an alternative path for lowering tensor.pad - by
4100 // decomposing it into e.g. linalg.fill.
4102 }
4104
4105 TrackingListener listener(state, *this);
4106 if (failed(
4107 applyPatternsGreedily(target, std::move(patterns),
4108 GreedyRewriteConfig().setListener(&listener))))
4109 return emitDefaultDefiniteFailure(target);
4110
4111 results.push_back(target);
4113}
4114
4115//===----------------------------------------------------------------------===//
4116// VectorizeOp
4117//===----------------------------------------------------------------------===//
4118
4119DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4120 transform::TransformRewriter &rewriter,
4121 mlir::transform::TransformResults &transformResults,
4122 mlir::transform::TransformState &state) {
4123 auto targets = state.getPayloadOps(getTarget());
4124 if (std::empty(targets))
4126 auto transformOp = cast<TransformOpInterface>(getOperation());
4127 SmallVector<int64_t> vectorSizes;
4128 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
4129 state, transformOp, getMixedVectorSizes(), vectorSizes);
4130 if (!status.succeeded())
4131 return status;
4132
4133 // TODO: Check that the correct number of vectorSizes was provided.
4134 for (Operation *target : targets) {
4136 return mlir::emitSilenceableFailure(target->getLoc())
4137 << "Unsupported Op, cannot vectorize";
4138 }
4139 FailureOr<VectorizationResult> vectorResults =
4140 linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
4141 getVectorizeNdExtract().value_or(false),
4142 /*flatten1DDepthwiseConv=*/false,
4143 getAssumeDynamicDimsMatchVecSizes().value_or(false),
4144 getCreateNamedContraction().value_or(false));
4145 if (failed(vectorResults)) {
4146 return mlir::emitSilenceableFailure(target->getLoc())
4147 << "Attempted to vectorize, but failed";
4148 }
4149 rewriter.replaceOp(target, vectorResults->replacements);
4150 }
4151
4153}
4154
void transform::VectorizeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Vectorization replaces the targeted payload ops, so the target handle is
  // consumed; vector-size handles are only read.
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getVectorSizesMutable(), effects);
  modifiesPayload(effects);
}
4161
4162SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4163 OpBuilder b(getContext());
4164 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
4165}
4166
4167LogicalResult transform::VectorizeOp::verify() {
4168 if (getStaticVectorSizes().size() != getScalableSizes().size())
4169 return emitOpError("expected same number of vector sizes (")
4170 << getStaticVectorSizes().size() << ") and scalable sizes ("
4171 << getScalableSizes().size() << ")";
4172 return success();
4173}
4174
4175//===----------------------------------------------------------------------===//
4176// HoistRedundantVectorTransfersOp
4177//===----------------------------------------------------------------------===//
4178
4179DiagnosedSilenceableFailure
4180transform::HoistRedundantVectorTransfersOp::applyToOne(
4181 transform::TransformRewriter &rewriter, func::FuncOp target,
4182 transform::ApplyToEachResultList &results,
4183 transform::TransformState &state) {
4184 // WARNING: This hoisting does not model parallelism and is generally
4185 // incorrect when used on distributed loops with memref semantics!
4186 // TODO: obsolete and should be retired.
4187 linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
4188 results.push_back(target);
4190}
4191
4192//===----------------------------------------------------------------------===//
4193// HoistRedundantVectorBroadcastsOp
4194//===----------------------------------------------------------------------===//
4195
4196DiagnosedSilenceableFailure
4197transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4198 transform::TransformRewriter &rewriter, mlir::Operation *target,
4199 transform::ApplyToEachResultList &results,
4200 transform::TransformState &state) {
4201 rewriter.setInsertionPoint(target);
4203 results.push_back(target);
4205}
4206
4207//===----------------------------------------------------------------------===//
4208// ConvertConv2DToImg2ColOp.
4209//===----------------------------------------------------------------------===//
4210
4211DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4212 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4213 transform::ApplyToEachResultList &results,
4214 transform::TransformState &state) {
4215 rewriter.setInsertionPoint(target);
4216 auto maybeTransformed =
4218 target)
4219 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4220 return rewriteInIm2Col(rewriter, op);
4221 })
4222 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4223 return rewriteInIm2Col(rewriter, op);
4224 })
4225 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4226 return rewriteInIm2Col(rewriter, op);
4227 })
4228 .Case([&](linalg::Conv2DNchwFchwOp op) {
4229 return rewriteInIm2Col(rewriter, op);
4230 })
4231 .Default([&](Operation *op) {
4232 return rewriter.notifyMatchFailure(op, "not supported");
4233 });
4234 if (failed(maybeTransformed))
4235 return emitDefaultSilenceableFailure(target);
4236 // Handle to the operation producing the img2col tensor.
4237 results.push_back(maybeTransformed->first);
4238 // Handle to the operation that replaces the original convolution.
4239 results.push_back(maybeTransformed->second);
4241}
4242
4243//===----------------------------------------------------------------------===//
4244// FlattenElementwiseLinalgOp.
4245//===----------------------------------------------------------------------===//
4246
4247DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4248 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4249 transform::ApplyToEachResultList &results,
4250 transform::TransformState &state) {
4251 rewriter.setInsertionPoint(target);
4252 if (!isElementwise(target))
4253 return mlir::emitSilenceableFailure(target->getLoc())
4254 << "only elementwise flattening is supported";
4255
4256 // If rank <= 1, do nothing
4257 if (target.getNumLoops() <= 1) {
4258 results.push_back(target);
4260 }
4261
4262 // Attempt to flatten all dims to one.
4263 ReassociationIndices reassociation(target.getNumLoops());
4264 std::iota(reassociation.begin(), reassociation.end(), 0);
4265 auto maybeFlattened =
4266 collapseOpIterationDims(target, reassociation, rewriter);
4267 if (failed(maybeFlattened))
4268 return mlir::emitSilenceableFailure(target->getLoc())
4269 << "attempted to flatten, but failed";
4270 results.push_back(maybeFlattened->collapsedOp);
4271 rewriter.replaceOp(target, maybeFlattened->results);
4273}
4274
4275//===----------------------------------------------------------------------===//
4276// TransposeConv2DOp
4277//===----------------------------------------------------------------------===//
4278
4279DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4280 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4281 transform::ApplyToEachResultList &results,
4282 transform::TransformState &state) {
4283 rewriter.setInsertionPoint(target);
4284 auto maybeTransformed =
4286 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4287 return transposeConv2D(rewriter, op);
4288 })
4289 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4290 return transposeConv2D(rewriter, op);
4291 })
4292 .Default([&](Operation *op) {
4293 return rewriter.notifyMatchFailure(op, "not supported");
4294 });
4295 if (failed(maybeTransformed))
4296 return emitDefaultSilenceableFailure(target);
4297 // Handle to the new Conv2D operation with transposed filters
4298 results.push_back(*maybeTransformed);
4300}
4301
4302//===----------------------------------------------------------------------===//
4303// TransposeMatmulOp
4304//===----------------------------------------------------------------------===//
4305
4306DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4307 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4308 transform::ApplyToEachResultList &results,
4309 transform::TransformState &state) {
4310 rewriter.setInsertionPoint(target);
4311 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4312 auto maybeTransformed =
4314 .Case([&](linalg::MatmulOp op) {
4315 return transposeMatmul(rewriter, op, transposeLHS);
4316 })
4317 .Case([&](linalg::BatchMatmulOp op) {
4318 return transposeBatchMatmul(rewriter, op, transposeLHS);
4319 })
4320 .Default(failure());
4321 if (failed(maybeTransformed))
4322 return emitSilenceableFailure(target->getLoc()) << "not supported";
4323 // Handle to the new Matmul operation with transposed filters
4324 results.push_back(*maybeTransformed);
4326}
4327
4328//===----------------------------------------------------------------------===//
4329// InsertSliceToCopyOp
4330//===----------------------------------------------------------------------===//
4331template <typename OpTy>
4332static DiagnosedSilenceableFailure
4333doit(RewriterBase &rewriter, OpTy target,
4336 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4337 tensor::ParallelInsertSliceOp>() &&
4338 "wrong op type");
4339
4340 if (auto copySource =
4341 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4342 results.push_back(copySource);
4344 }
4345
4346 // If we are inside a `ParallelCombiningOp` region, temporarily set the
4347 // insertion point outside: only ops implementing ParallelCombiningOpInterface
4348 // are allowed in there.
4349 if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4350 rewriter.setInsertionPoint(target->getParentOp());
4351
4352 Value extracted = tensor::ExtractSliceOp::create(
4353 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4354 target.getMixedSizes(), target.getMixedStrides());
4355 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4356 target.getSource(), extracted)
4357 .getResult(0);
4358 // Reset the insertion point.
4359 rewriter.setInsertionPoint(target);
4360 rewriter.replaceOpWithNewOp<OpTy>(
4361 target, copied, target.getDest(), target.getMixedOffsets(),
4362 target.getMixedSizes(), target.getMixedStrides());
4363
4364 results.push_back(copied.getDefiningOp());
4366}
4367
4368DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4369 transform::TransformRewriter &rewriter, Operation *targetOp,
4370 transform::ApplyToEachResultList &results,
4371 transform::TransformState &state) {
4372
4373 rewriter.setInsertionPoint(targetOp);
4374 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4375 return doit(rewriter, target, results, state);
4376 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4377 return doit(rewriter, target, results, state);
4378
4379 DiagnosedSilenceableFailure diag =
4380 emitSilenceableError()
4381 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4382 diag.attachNote(targetOp->getLoc()) << "target op";
4383 return diag;
4384}
4385
4386//===----------------------------------------------------------------------===//
4387// MapCopyToThreadsOp
4388//===----------------------------------------------------------------------===//
4389
4390DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4391 transform::TransformRewriter &rewriter, Operation *target,
4392 transform::ApplyToEachResultList &results,
4393 transform::TransformState &state) {
4394 // Check if the op is supported.
4395 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4396 DiagnosedSilenceableFailure diag =
4397 emitSilenceableError()
4398 << "only linalg.copy and tensor.pad target ops are supported";
4399 diag.attachNote(target->getLoc()) << "target op";
4400 return diag;
4401 }
4402 assert(target->getNumResults() == 1 && "expected single result");
4403 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4404 if (!resultShapedType.hasStaticShape()) {
4405 DiagnosedSilenceableFailure diag =
4406 emitSilenceableError()
4407 << "only statically sized ops of rank <= 3 are supported";
4408 diag.attachNote(target->getLoc()) << "target op";
4409 return diag;
4410 }
4411
4412 // Conservatively set the minimum viable desired bitwidth alignment.
4413 int64_t desiredBitAlignment = getDesiredBitAlignment();
4414 int64_t eltBitwidth =
4415 resultShapedType.getElementType().getIntOrFloatBitWidth();
4416 if (desiredBitAlignment % eltBitwidth != 0) {
4417 desiredBitAlignment = eltBitwidth;
4418 }
4419
4420 gpu::CopyMappingInfo mapping(
4421 /*ctx=*/getContext(),
4422 /*totalNumThreads=*/getTotalNumThreads(),
4423 /*alignment=*/desiredBitAlignment,
4424 /*sizes=*/resultShapedType.getShape(),
4425 /*favorPredication=*/false,
4426 /*elementalBitwidth=*/
4427 resultShapedType.getElementType().getIntOrFloatBitWidth());
4428 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4429 DiagnosedSilenceableFailure diag =
4430 emitSilenceableError()
4431 << "too few threads to map copy op to threads on the most minor "
4432 "dimension, given alignment and vector size constraints, try "
4433 "smaller tile size of mapping to more threads";
4434 diag.attachNote(target->getLoc()) << "target op";
4435 return diag;
4436 }
4437
4438 // OpBuilder only used to compute attributes.
4439 OpBuilder b(getContext());
4440 scf::SCFTilingResult tilingResult;
4441 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4442 /*rewriter=*/rewriter,
4443 /*state=*/state,
4444 /*transformOp=*/*this,
4445 /*target=*/target,
4446 /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4447 /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4448 /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4449 /*tilingResult=*/tilingResult);
4450 if (!diag.succeeded())
4451 return diag;
4452
4453 results.push_back(tilingResult.loops.front());
4454 for (auto *op : tilingResult.tiledOps)
4455 results.push_back(op);
4457}
4458
4459//===----------------------------------------------------------------------===//
4460// WinogradConv2DOp
4461//===----------------------------------------------------------------------===//
4462
4463DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4464 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4465 transform::ApplyToEachResultList &results,
4466 transform::TransformState &state) {
4467 rewriter.setInsertionPoint(target);
4468 FailureOr<Operation *> maybeTransformed = failure();
4469 bool supported = TypeSwitch<Operation *, bool>(target)
4470 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4471 maybeTransformed =
4472 winogradConv2D(rewriter, op, getFmr());
4473 return true;
4474 })
4475 .Default([&](Operation *op) { return false; });
4476
4477 if (!supported) {
4478 return emitSilenceableError()
4479 << "this operation is not supported to convert to Winograd Conv2D";
4480 }
4481
4482 if (failed(maybeTransformed)) {
4483 return emitSilenceableError() << "apply Winograd Conv2D failed";
4484 }
4485
4486 results.push_back(*maybeTransformed);
4488}
4489
4490DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4491 transform::TransformRewriter &rewriter, Operation *target,
4492 transform::ApplyToEachResultList &results,
4493 transform::TransformState &state) {
4494 rewriter.setInsertionPoint(target);
4495 FailureOr<Operation *> maybeTransformed = failure();
4496 bool supported =
4498 .Case([&](linalg::WinogradFilterTransformOp op) {
4499 maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4500 return true;
4501 })
4502 .Case([&](linalg::WinogradInputTransformOp op) {
4503 maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4504 return true;
4505 })
4506 .Case([&](linalg::WinogradOutputTransformOp op) {
4507 maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4508 return true;
4509 })
4510 .Default(false);
4511
4512 if (!supported) {
4513 DiagnosedSilenceableFailure diag =
4514 emitSilenceableError()
4515 << "this operation is not supported to decompose into other operations";
4516 diag.attachNote(target->getLoc()) << "target op";
4517 return diag;
4518 }
4519
4520 if (failed(maybeTransformed)) {
4521 DiagnosedSilenceableFailure diag =
4522 emitSilenceableError() << "decompose Winograd operations failed";
4523 diag.attachNote(target->getLoc()) << "target op";
4524 return diag;
4525 }
4526
4527 results.push_back(*maybeTransformed);
4529}
4530
4531#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4532
4533#define GET_OP_CLASSES
4534#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized. TODO: Replace this with a g...
ArrayAttr()
static bool mayBeRead(OpOperand &operand)
Return true if the operand may be read from by its owner.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static bool sameOrEquivalentIterArg(Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tileSizes, Type)
static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Base type for affine expression.
Definition AffineExpr.h:68
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition Builders.cpp:310
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:318
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getOpResult(unsigned idx)
Definition Operation.h:447
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
void setOperand(unsigned idx, Value value)
Definition Operation.h:377
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:586
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:289
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
result_range getOpResults()
Definition Operation.h:446
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation, bool PoisonPaddingOk=false)
Patterns to bubble up or down data layout ops across other operations.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
void populateExtractSliceSinkingPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to sink extract slice across other operations.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice + copy.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Canonicalization patterns relevant to apply after tiling patterns.
Definition Tiling.cpp:866
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition Transforms.h:489
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors and memref.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:217
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition Tiling.cpp:236
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition Tiling.cpp:106
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition Hoisting.cpp:89
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of linalg.pack/unpack ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition Hoisting.cpp:198
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition Split.cpp:68
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition Tiling.cpp:262
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:675
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract_slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:287
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition Transforms.h:521
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition Transforms.h:409
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition Transforms.h:456
LinalgPromotionOptions & setAlignment(unsigned align)
Definition Transforms.h:433
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition Transforms.h:440
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition Transforms.h:398
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition Transforms.h:420
LinalgPromotionOptions & setUseOriginalSubviewSize(bool originalSize)
Definition Transforms.h:427
LinalgPromotionOptions & setUseAlloca(bool use)
Definition Transforms.h:446
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition Transforms.h:466