MLIR 23.0.0git
LinalgTransformOps.cpp
Go to the documentation of this file.
1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
40#include "mlir/Support/LLVM.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/ScopeExit.h"
44#include "llvm/ADT/SmallPtrSet.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
49#include <type_traits>
50
51using namespace mlir;
52using namespace mlir::linalg;
53using namespace mlir::transform;
54
55#define DEBUG_TYPE "linalg-transforms"
56
57/// Attempts to apply the pattern specified as template argument to the given
58/// operation. The pattern is expected to have a `returningMatchAndRewrite`
59/// function that returns the "main" result or failure. Returns failure if the
60/// pattern failed to apply. Extra arguments are forwarded to the pattern
61/// constructor.
62template <typename PatternTy, typename... Args>
63static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
64 // Check if the given operation has the type expected by the pattern.
65 using OpTy = typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
68 if (!op)
69 return failure();
70
71 // Apply the pattern directly to the op.
72 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
73 // We want to discourage direct use of PatternRewriter in APIs but In this
74 // very specific case, an IRRewriter is not enough.
75 PatternRewriter rewriter(operation->getContext());
76 rewriter.setInsertionPoint(operation);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
78 if (failed(result))
79 return failure();
80 return cast<LinalgOp>(result->getOperation());
81}
82
83/// Assuming that `ofr` is an index attr or a param of index type
84/// or a transform dialect handle mapped to exactly one op
85/// with one index result, return that value.
// NOTE(review): the source-line numbers embedded in this listing jump
// (85->87, 87->89, 108->110, 118->120, 128->130). The leading signature line,
// one parameter line, the declarations of `diag` and the final return are
// presumably missing here; only the visible logic is described.
//
// Each entry of `ofrs` is appended to `result` as follows:
//  - an attribute must be an IntegerAttr and is forwarded unchanged,
//    otherwise a definite failure is emitted;
//  - a value whose type is a transform param type must have exactly one
//    associated parameter, which is forwarded, otherwise a definite failure
//    is emitted;
//  - any other value is treated as a handle: it must be mapped to exactly one
//    payload op that has exactly one result of index type; that result is
//    forwarded, otherwise a silenceable error is returned.
87 transform::TransformState &state, TransformOpInterface transformOp,
89 for (OpFoldResult ofr : ofrs) {
90 if (auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
93 result.push_back(ofr);
94 continue;
95 }
96
  // Not an attribute: either a transform param or a handle.
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
99 ArrayRef<Attribute> params = state.getParams(transformValue);
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 << "requires exactly one parameter associated";
103 result.push_back(params[0]);
104 continue;
105 }
106
107 auto payloadOps = state.getPayloadOps(transformValue);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 << "handle must be mapped to exactly one payload op";
112 diag.attachNote(transformValue.getLoc())
113 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
114 return diag;
115 }
116
  // The single payload op must produce exactly one value of index type.
117 Operation *op = *payloadOps.begin();
118 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
120 transformOp.emitSilenceableError()
121 << "payload op must have exactly 1 index result";
122 diag.attachNote(op->getLoc())
123 << "has " << op->getNumResults() << " results";
124 return diag;
125 }
126 result.push_back(op->getResult(0));
127 }
128
130}
131
132// Given a list of params that are index attrs or a list of OpFoldResults
133// that are either index attrs or op handles, return a list of OpFoldResults
134// of index attrs or a list of OpFoldResults where all op handles are
135// replaced with the first (and only) OpResult of that payload op.
136// (There must be exactly one parameter associated with the AnyParamType or
137// one mapped payload op which must have exactly one index result.)
//
// NOTE(review): the source-line numbers embedded in this listing jump
// (137->139, 149->151, 155->157, 164->166). The leading signature line, the
// statement that ends the param branch, the declaration of `diag` and the
// final return are presumably missing here.
//
// Visible behavior:
//  - if `packedHandle` has a transform param type, every associated parameter
//    must be an IntegerAttr (definite failure otherwise) and is appended to
//    `result`;
//  - for a handle, every mapped payload op must have exactly one result of
//    index type (silenceable error otherwise); that result is appended to
//    `result`.
139 transform::TransformState &state, TransformOpInterface transformOp,
140 SmallVector<OpFoldResult> &result, Value packedHandle) {
141 if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
142 ArrayRef<Attribute> params = state.getParams(packedHandle);
143 for (auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 << "expected the parameter to be associated with an integer "
147 "attribute";
148 result.push_back(param);
149 }
151 }
152
153 for (Operation *op : state.getPayloadOps(packedHandle)) {
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 << "payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 << "has " << op->getNumResults() << " results";
160 return diag;
161 }
162 result.push_back(op->getResult(0));
163 }
164
166}
167
168/// When possible, converts each `OpFoldResult` in `mixedResult` to
169/// an integer if the value can be statically inferred. If a result
170/// is a `Value` then it must be either a `ParamType` or a handle
171/// to a constant-like op.
// NOTE(review): the source-line numbers embedded in this listing jump
// (171->173, 213->215); the leading signature line and the final return are
// presumably missing here.
//
// Visible behavior, per entry of `mixedResults`:
//  - an attribute is cast to IntegerAttr (no dynamic check) and its integer
//    value is appended to `reified`;
//  - a value of transform param type must have exactly one associated
//    parameter, which is cast to IntegerAttr and appended;
//  - any other value must be of a transform handle type, be mapped to exactly
//    one payload op with exactly one index-typed result, and that result must
//    match a constant; the constant's integer value is appended.
// Every violation reported in the param/handle paths is a silenceable error.
173 TransformState &state, TransformOpInterface &transformOp,
174 ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
175 for (OpFoldResult paramOrHandle : mixedResults) {
176 if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
178 continue;
179 }
180 if (isa<TransformParamTypeInterface>(
181 cast<Value>(paramOrHandle).getType())) {
182 ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
183 if (params.size() != 1)
184 return transformOp.emitSilenceableError() << "expected a single param";
185 reified.push_back(
186 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
187 continue;
188 }
189
  // Remaining case: the value has to be a handle to a single payload op.
190 Value handle = cast<Value>(paramOrHandle);
191 if (!isa<TransformHandleTypeInterface>(handle.getType()))
192 return transformOp.emitSilenceableError() << "unexpected value handle";
193 auto payload = state.getPayloadOps(handle);
194 if (!llvm::hasSingleElement(payload))
195 return transformOp.emitSilenceableError()
196 << "requires param or handle that is mapped to 1 payload op";
197
198 Operation *paramOrHandlePayloadOp = *payload.begin();
199 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
200 !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
201 return transformOp.emitSilenceableError()
202 << "requires param or handle to be result of op with 1 index "
203 "result";
204 }
205
  // Only results that match a constant can be reified to an integer.
206 IntegerAttr attr;
207 if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
208 return transformOp.emitSilenceableError()
209 << "requires param or handle to be the result of a constant like "
210 "op";
211
212 reified.push_back(attr.getInt());
213 }
215}
216
217//===----------------------------------------------------------------------===//
218// Apply...PatternsOp
219//===----------------------------------------------------------------------===//
220
// NOTE(review): the body line (embedded source line 223) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
221void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
222 RewritePatternSet &patterns) {
224}
225
// NOTE(review): the body line (embedded source line 228) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
226void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
227 RewritePatternSet &patterns) {
229}
230
// NOTE(review): the body line (embedded source line 233) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
231void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
232 RewritePatternSet &patterns) {
234}
235
// NOTE(review): the body lines (embedded source lines 238-239) are missing
// from this listing; presumably they add the matching rewrite patterns to
// `patterns`.
236void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
237 RewritePatternSet &patterns) {
240}
241
// NOTE(review): embedded source lines 244, 246 and 247 are missing from this
// listing: the declaration of `options`, the value assigned to
// `options.rankReductionStrategy`, and presumably the call that adds the
// patterns to `patterns`.
242void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
243 RewritePatternSet &patterns) {
245 options.rankReductionStrategy =
248}
249
// NOTE(review): the body line (embedded source line 252) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
250void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
251 RewritePatternSet &patterns) {
253}
254
// NOTE(review): the body line (embedded source line 257) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
255void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
256 RewritePatternSet &patterns) {
258}
259
// NOTE(review): the body line (embedded source line 262) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
260void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
261 RewritePatternSet &patterns) {
263}
264
// NOTE(review): the body line (embedded source line 267) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
265void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
266 RewritePatternSet &patterns) {
268}
269
// NOTE(review): the body line (embedded source line 272) is missing from this
// listing; presumably it adds the matching rewrite patterns to `patterns`.
270void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
271 RewritePatternSet &patterns) {
273}
274
275void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
276 RewritePatternSet &patterns) {
277 linalg::ControlPropagationFn defaultControlFn = [](OpOperand *operand) {
278 return true;
279 };
280 linalg::populateDataLayoutPropagationPatterns(patterns, defaultControlFn,
281 getPoisonPadding());
282}
283
284void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
285 RewritePatternSet &patterns) {
286 linalg::ControlPropagationFn defaultControlFn =
287 [](OpOperand *opOperand) -> bool {
288 Operation *producer = opOperand->get().getDefiningOp();
289 Operation *consumer = opOperand->getOwner();
290 return consumer->getBlock() == producer->getBlock();
291 };
292 linalg::populateExtractSliceSinkingPatterns(patterns, defaultControlFn);
293}
294
295//===----------------------------------------------------------------------===//
296// BufferizeToAllocationOp
297//===----------------------------------------------------------------------===//
298
299namespace {
// Forwarding listener that records the ops created while it is attached to a
// rewriter and forgets those that get erased again. All notifications are
// forwarded to the base ForwardingListener first.
//
// NOTE(review): embedded source lines 302 and 325 are missing from this
// listing; presumably they hold the constructor and the `newOps` member
// declaration.
300class NewOpsListener : public RewriterBase::ForwardingListener {
301public:
303
  // Returns a copy of the currently tracked ops as a vector.
304 SmallVector<Operation *> getNewOps() const {
305 return SmallVector<Operation *>(newOps.begin(), newOps.end());
306 }
307
308private:
  // Tracks `op` when the insertion has no previous insertion point set;
  // insertions with a set `previous` are ignored. Asserts that the op was not
  // tracked already.
309 void notifyOperationInserted(Operation *op,
310 OpBuilder::InsertPoint previous) override {
311 ForwardingListener::notifyOperationInserted(op, previous);
312 // We only care about newly created ops.
313 if (previous.isSet())
314 return;
315 auto inserted = newOps.insert(op);
316 (void)inserted;
317 assert(inserted.second && "expected newly created op");
318 }
319
  // Stops tracking every op visited by walking `op`.
320 void notifyOperationErased(Operation *op) override {
321 ForwardingListener::notifyOperationErased(op);
322 op->walk([&](Operation *op) { newOps.erase(op); });
323 }
324
326};
327} // namespace
328
// Bufferizes every payload op of the target handle to an allocation and
// reports the allocated buffers and the ops created along the way.
//
// NOTE(review): the source-line numbers embedded in this listing jump
// (329->332, 338->340, 344->346, 347->349, 353->355, 356->358, 382->384). The
// parameter list, the declaration of `options`, the enum values assigned in
// four branches and the final return are presumably missing here.
//
// Visible steps:
//  1. A NewOpsListener is installed on `rewriter`; the previous listener is
//     restored on scope exit.
//  2. `options` is filled from the op's memcpy/alloc op names and its
//     bufferize-destination-only / emit-dealloc settings. Unknown names hit
//     llvm_unreachable.
//  3. Each payload op is passed to linalg::bufferizeToAllocation; a null
//     buffer yields a silenceable error with a note on the payload op.
//  4. The buffers and the listener's new ops are set as results.
329DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
332 // Attach listener to keep track of newly created ops.
333 OpBuilder::Listener *previousListener = rewriter.getListener();
334 llvm::scope_exit resetListener(
335 [&]() { rewriter.setListener(previousListener); });
336 NewOpsListener newOpsListener(previousListener);
337 rewriter.setListener(&newOpsListener);
338
  // Translate the memcpy op name into the corresponding option.
340 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
341 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
342 MaterializeInDestination;
343 } else if (getMemcpyOp() == "memref.copy") {
344 options.memcpyOp =
346 } else if (getMemcpyOp() == "linalg.copy") {
347 options.memcpyOp =
349 } else {
350 llvm_unreachable("invalid memcpy op");
351 }
  // Translate the alloc op name into the corresponding option.
352 if (getAllocOp() == "memref.alloc") {
353 options.allocOp =
355 } else if (getAllocOp() == "memref.alloca") {
356 options.allocOp =
358 } else {
359 llvm_unreachable("invalid alloc op");
360 }
361 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
362 options.emitDealloc = getEmitDealloc();
363
364 // Bufferize ops.
  // A default-constructed (null) Attribute is used when no memory space is
  // set on the op.
365 Attribute memorySpace =
366 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
367 SmallVector<Value> allocatedBuffers;
368 for (Operation *op : state.getPayloadOps(getTarget())) {
369 Value buffer =
370 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
371 if (!buffer) {
372 DiagnosedSilenceableFailure diag = emitSilenceableError()
373 << "failed to bufferize operation";
374 diag.attachNote(op->getLoc()) << "target payload op";
375 return diag;
376 }
377 allocatedBuffers.push_back(buffer);
378 }
379
380 // Set results.
381 results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
382 results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
384}
385
// Declares the effects of the op: the target handle is only read when just
// the destination is bufferized and consumed otherwise; all results are
// produced handles and the payload is modified.
// NOTE(review): embedded source line 387 (presumably the `effects` parameter
// declaration) is missing from this listing.
386void transform::BufferizeToAllocationOp::getEffects(
388 if (getBufferizeDestinationOnly()) {
389 // The destination is replaced with a newly allocated buffer, but the op
390 // itself remains in place.
391 onlyReadsHandle(getTargetMutable(), effects);
392 } else {
393 consumesHandle(getTargetMutable(), effects);
394 }
395 producesHandle(getOperation()->getOpResults(), effects);
396 modifiesPayload(effects);
397}
398
399LogicalResult transform::BufferizeToAllocationOp::verify() {
400 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
401 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
402 return emitOpError() << "unsupported memcpy op";
403 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
404 return emitOpError() << "unsupported alloc op";
405 return success();
406}
407
408//===----------------------------------------------------------------------===//
409// PromoteTensorOp
410//===----------------------------------------------------------------------===//
411
412/// Return true if the operand may be read from by its owner. This is currently
413/// very conservative and only looks inside linalg operations to prevent
414/// unintentional data loss.
415static bool mayBeRead(OpOperand &operand) {
416 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
417
418 // Be conservative about ops we cannot analyze deeper.
419 if (!linalgOp)
420 return true;
421
422 // Look inside linalg ops.
423 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
424 return !blockArgument.use_empty();
425}
426
427/// Return true if the value may be read through any of its uses.
428static bool mayBeRead(Value value) {
429 // If the value has a reference semantics, it
430 // may be read through any alias...
431 if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
432 return true;
433 return llvm::any_of(value.getUses(),
434 static_cast<bool (&)(OpOperand &)>(mayBeRead));
435}
436
// Replaces every payload tensor value with a newly created
// bufferization.alloc_tensor, optionally initialized from the original value.
//
// NOTE(review): the source-line numbers embedded in this listing jump
// (436->438, 438->441, 457->459, 488->490). The return type, part of the
// parameter list, the declaration of `preservedOps` and the final return are
// presumably missing here.
//
// Visible steps, per payload value:
//  1. The value must be a ranked tensor, otherwise a silenceable error is
//     emitted.
//  2. New ops are inserted right after the defining op, or at the start of
//     the owning block for block arguments.
//  3. A tensor.dim op is created for each dynamic dimension and passed to the
//     alloc_tensor; the memory space attribute is set when present.
//  4. When the original value may be read, a materialize_in_destination op
//     copies it into the allocation and its result becomes the promoted
//     value; otherwise the allocation itself is the promoted value.
//  5. All uses of the original value except those in `preservedOps` are
//     replaced with the promoted value.
438transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
441 SmallVector<Value> promoted;
442 for (Value tensor : state.getPayloadValues(getTensor())) {
443 auto type = dyn_cast<RankedTensorType>(tensor.getType());
444 if (!type) {
445 return emitSilenceableError() << "non-tensor type: " << tensor;
446 }
447
448 Operation *definingOp = tensor.getDefiningOp();
449 if (definingOp)
450 rewriter.setInsertionPointAfter(definingOp);
451 else
452 rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
453
454 // Check this before we emit operations using this value.
455 bool needsMaterialization = mayBeRead(tensor);
456
457 SmallVector<Value> dynamicDims;
  // The dim ops created below use `tensor` themselves, so they are recorded
  // in `preservedOps` and excluded from the final use replacement.
459 for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
460 if (!ShapedType::isDynamic(dim))
461 continue;
462 Value cst =
463 arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
464 auto dimOp =
465 tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
466 preservedOps.insert(dimOp);
467 dynamicDims.push_back(dimOp);
468 }
469 auto allocation = bufferization::AllocTensorOp::create(
470 rewriter, tensor.getLoc(), type, dynamicDims);
471 // Set memory space if provided.
472 if (getMemorySpaceAttr())
473 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
474 Value allocated = allocation;
475
476 // Only insert a materialization (typically bufferizes to a copy) when the
477 // value may be read from.
478 if (needsMaterialization) {
479 auto copy = bufferization::MaterializeInDestinationOp::create(
480 rewriter, tensor.getLoc(), tensor, allocated);
481 preservedOps.insert(copy);
482 promoted.push_back(copy.getResult());
483 } else {
484 promoted.push_back(allocated);
485 }
486 rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
487 }
488 results.setValues(cast<OpResult>(getPromoted()), promoted);
490}
491
// Declares the effects of the op: the tensor handle is only read and all
// results are produced handles.
// NOTE(review): embedded source lines 493 and 496 are missing from this
// listing; presumably they hold the `effects` parameter declaration and one
// more effect declaration.
492void transform::PromoteTensorOp::getEffects(
494 transform::onlyReadsHandle(getTensorMutable(), effects);
495 transform::producesHandle(getOperation()->getOpResults(), effects);
497}
498
499//===----------------------------------------------------------------------===//
500// DecomposeOp
501//===----------------------------------------------------------------------===//
502
// Applies a decomposition to `target`; on success the resulting op is pushed
// to `results`, otherwise the default silenceable failure is emitted.
// NOTE(review): the source-line numbers embedded in this listing jump
// (502->504, 505->508, 508->510, 511->513). The return type, the remaining
// parameters, the call that computes `res` and the statement that ends the
// success branch are presumably missing here.
504transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
505 LinalgOp target,
508 FailureOr<linalg::LinalgOp> res =
510 if (succeeded(res)) {
511 results.push_back(*res);
513 }
514 return emitDefaultSilenceableFailure(target);
515}
516
517//===----------------------------------------------------------------------===//
518// DecomposeInterfaceOp
519//===----------------------------------------------------------------------===//
520
521// Decompose the target operation if it implements the AggregatedOpInterface.
522// Push the decomposed operations (the ones that replace the values produced by
523// \p target) in the `results`.
//
// NOTE(review): the source-line numbers embedded in this listing jump
// (524->528, 529->531, 545->547). The parameter list, the start of the
// statement ending in "payload is not a decomposable op" and the final return
// are presumably missing here.
//
// Replacement values that have no defining op are not pushed to `results`.
524DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
528 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
529 if (!decomposableOp) {
531 "payload is not a decomposable op"));
532 return emitDefaultSilenceableFailure(target);
533 }
534
535 FailureOr<SmallVector<Value>> maybeNewResults =
536 decomposableOp.decomposeOperation(rewriter);
537 if (failed(maybeNewResults))
538 return emitDefaultSilenceableFailure(target);
539
  // Swap the original op for the decomposed values, then report the ops that
  // define them.
540 rewriter.replaceOp(decomposableOp, *maybeNewResults);
541 for (Value val : *maybeNewResults) {
542 Operation *definition = val.getDefiningOp();
543 if (definition)
544 results.push_back(definition);
545 }
547}
548
549//===----------------------------------------------------------------------===//
550// EliminateLinalgOpAnchoredEmptyTensorsOp
551//===----------------------------------------------------------------------===//
552
// Declares the effects of the op: the target handle is only read and the
// payload is modified.
// NOTE(review): embedded source line 554 (presumably the `effects` parameter
// declaration) is missing from this listing.
553void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
555 onlyReadsHandle(getTargetMutable(), effects);
556 modifiesPayload(effects);
557}
558
// For every payload op of the target handle: analyzes the op and then runs
// the elimination of LinalgOp-anchored tensor.empty ops on it. A failure of
// either step is reported as a silenceable failure at the payload op's
// location.
// NOTE(review): the source-line numbers embedded in this listing jump
// (558->560, 562->564, 566->568, 570->572, 575->577). The return type, the
// declaration of `options`, the declaration of the `state` object used inside
// the loop (it presumably shadows the `state` parameter), the start of the
// elimination call and the final return are presumably missing here.
560transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
561 transform::TransformRewriter &rewriter, TransformResults &transformResults,
562 TransformState &state) {
564 options.allowReturnAllocsFromLoops = true;
565
566 for (Operation *target : state.getPayloadOps(getTarget())) {
568 if (failed(analyzeOp(target, state)))
569 return mlir::emitSilenceableFailure(target->getLoc())
570 << "failed to analyze op";
572 rewriter, target, state)))
573 return mlir::emitSilenceableFailure(target->getLoc())
574 << "failed to eliminate LinalgOp anchored tensor.empty ops";
575 }
577}
578
579//===----------------------------------------------------------------------===//
580// FuseOp
581//===----------------------------------------------------------------------===//
582
583void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
584 TypeRange loopTypes, Value target,
585 ArrayRef<int64_t> staticTileSizes,
586 ArrayRef<int64_t> staticTileInterchange,
587 bool applyCleanup, bool useForall) {
588 return build(
589 builder, result, loopTypes,
590 /*target=*/target,
591 /*mixedTileSizes=*/
592 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
593 /*mixedTileInterchange=*/
594 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
595 applyCleanup, useForall);
596}
597
598void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
599 Value target, ArrayRef<int64_t> staticTileSizes,
600 ArrayRef<int64_t> staticTileInterchange,
601 bool applyCleanup, bool useForall) {
602 return build(
603 builder, result,
604 /*target=*/target,
605 /*mixedTileSizes=*/
606 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
607 /*mixedTileInterchange=*/
608 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
609 applyCleanup, useForall);
610}
611
// Builder taking mixed (static or dynamic) tile sizes and interchange without
// explicit loop result types; a single AnyOpType loop type is passed on to
// the builder that takes loop types.
// NOTE(review): embedded source line 613 (presumably the `target` parameter
// declaration) is missing from this listing.
612void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
614 ArrayRef<OpFoldResult> mixedTileSizes,
615 ArrayRef<OpFoldResult> mixedTileInterchange,
616 bool applyCleanup, bool useForall) {
617 // Loop types are automatically splat by the callee, setting up one is
618 // enough.
619 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
620 build(builder, result, loopTypes, target, mixedTileSizes,
621 mixedTileInterchange, applyCleanup, useForall);
622}
623
624void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
625 TypeRange loopTypes, Value target,
626 ArrayRef<OpFoldResult> mixedTileSizes,
627 ArrayRef<OpFoldResult> mixedTileInterchange,
628 bool applyCleanup, bool useForall) {
629 SmallVector<int64_t> staticTileSizes;
630 SmallVector<Value> dynamicTileSizes;
631 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
632 SmallVector<int64_t> staticTileInterchange;
633 SmallVector<Value> dynamicTileInterchange;
634 dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
635 staticTileInterchange);
636 // Call the default builder which sets up the proper operands segment sizes
637 // attributes for multiple variadic operands. In the absence of this,
638 // horrible bugs ensue.
639 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
640 auto staticTileInterchangeAttr =
641 builder.getDenseI64ArrayAttr(staticTileInterchange);
642 unsigned numExpectedLoops =
643 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
644 SmallVector<Type> resultTypes;
645 resultTypes.reserve(numExpectedLoops);
646 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
647 "expected one loop type or as many as loops");
648 if (loopTypes.size() == 1)
649 resultTypes.append(numExpectedLoops, loopTypes[0]);
650 else
651 llvm::append_range(resultTypes, loopTypes);
652 build(builder, result, /*transformed=*/target.getType(),
653 /*loops=*/resultTypes,
654 /*target=*/target,
655 /*tile_sizes=*/dynamicTileSizes,
656 /*tile_interchange=*/dynamicTileInterchange,
657 /*packed_tile_sizes=*/Value(),
658 /*static_tile_sizes=*/staticTileSizesAttr,
659 /*static_tile_interchange=*/staticTileInterchangeAttr,
660 /*apply_cleanup=*/applyCleanup,
661 /*use_forall=*/useForall);
662}
663
664/// Apply a tiling transformation to all payload ops and store both the
665/// tiled operation as well as the created tile loops.
666template <typename Range>
667static LogicalResult applyTilingToAll(
668 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
669 unsigned numLoops, transform::TransformResults &transformResults,
670 bool packedResults,
671 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
672 applyFn) {
673 SmallVector<Operation *> tiledLinalgOps;
674 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
675 size_t numTargets = llvm::range_size(payloadOps);
676
677 for (Operation *target : payloadOps) {
678 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
679 if (!tilingInterfaceOp)
680 return transformOp->emitError("only TilingInterface ops are supported");
681
682 rewriter.setInsertionPoint(target);
683 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
684 applyFn(tilingInterfaceOp);
685 if (failed(tiledResults))
686 return failure();
687
688 // Perform the replacement of tiled and fused values.
689 SmallVector<Operation *> opsToReplace{target};
690 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
691 for (Operation *toReplace : opsToReplace) {
692 for (OpResult res : toReplace->getResults())
693 if (auto replacement = tiledResults->replacements.lookup(res))
694 rewriter.replaceAllUsesWith(res, replacement);
695 if (toReplace->use_empty()) {
696 rewriter.eraseOp(toReplace);
697 }
698 }
699
700 // Report back the relevant handles to the transform op.
701 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
702 assert(tiledResults->loops.size() == numLoops &&
703 "Mismatched number of loops, tile and fuse transform should have "
704 "failed");
705 for (unsigned int i = 0; i < numLoops; ++i)
706 loopOps[i].push_back(tiledResults->loops[i]);
707 }
708
709 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
710 if (packedResults) {
711 // In case of packed results, all created loops are assigned to a single
712 // handle. Loops are returned in order of targets such as:
713 // %loops_handle = {
714 // target0:loop0, ..., target0:loopN,
715 // target1:loop0, ..., target1:loopN,
716 // ... }
717 SmallVector<Operation *> flattenedLoopOps;
718 for (unsigned int idx = 0; idx < numTargets; ++idx)
719 for (unsigned int i = 0; i < numLoops; ++i)
720 flattenedLoopOps.push_back(loopOps[i][idx]);
721 transformResults.set(transformOp->getOpResult(1), flattenedLoopOps);
722 } else {
723 for (unsigned int i = 0; i < numLoops; ++i)
724 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
725 }
726
727 return success();
728}
729
// Tiles every payload op of the target handle and fuses its producers, using
// the tile sizes and interchange configured on the op.
//
// NOTE(review): the source-line numbers embedded in this listing jump
// (729->731, 732->734, 736->738, 738->740, 740->742, 745->747, 764->767,
// 788->791). The return type, the `state` parameter, the declaration of
// `status` with the names of the unpacking helpers, the call that fills
// `tileInterchange`, two statements in the cleanup branch and the final
// return are presumably missing here.
//
// Visible steps:
//  1. Tile sizes are unpacked into `mixedTileSizes` from either the packed
//     tile sizes handle or the mixed tile sizes; the interchange is reified
//     into `tileInterchange`. A non-successful status is returned as is.
//  2. SCF tiling options are set up with the interchange, the loop type
//     (forall or for) and the tile sizes.
//  3. With apply-cleanup, the tensor.extract_slice canonicalization patterns
//     are among the cleanup patterns.
//  4. The number of loops is 1 with forall; otherwise it is the number of
//     tile sizes that are dynamic values or non-zero integer attributes.
//  5. applyTilingToAll runs tileConsumerAndFuseProducersUsingSCF on each
//     payload op; loop results are packed when the packed tile sizes handle
//     is present.
731transform::FuseOp::apply(transform::TransformRewriter &rewriter,
732 mlir::transform::TransformResults &transformResults,
734 auto transformOp = cast<TransformOpInterface>(getOperation());
735
736 SmallVector<OpFoldResult> mixedTileSizes;
738 getPackedTileSizes()
740 state, transformOp, mixedTileSizes, getPackedTileSizes())
742 state, transformOp, mixedTileSizes, getMixedTileSizes());
743 if (!status.succeeded())
744 return status;
745 SmallVector<int64_t> tileInterchange;
747 state, transformOp, getMixedTileInterchange(), tileInterchange);
748 if (!status.succeeded())
749 return status;
750
751 scf::SCFTilingOptions tilingOptions;
752 tilingOptions.interchangeVector = tileInterchange;
753 bool useForall = getUseForall();
754 tilingOptions.setLoopType(useForall
755 ? scf::SCFTilingOptions::LoopType::ForallOp
756 : scf::SCFTilingOptions::LoopType::ForOp);
757 tilingOptions = tilingOptions.setTileSizes(mixedTileSizes);
758 scf::SCFTileAndFuseOptions tileAndFuseOptions;
759 tileAndFuseOptions.tilingOptions = tilingOptions;
760
761 if (getApplyCleanup()) {
762 MLIRContext *context = rewriter.getContext();
763 RewritePatternSet patterns(context);
764 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
767 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
768 }
769
770 size_t numLoops;
771 if (useForall) {
772 numLoops = 1;
773 } else {
  // Dynamic tile sizes always count as a loop; static ones only when they are
  // not 0.
774 numLoops = llvm::count_if(mixedTileSizes, [](OpFoldResult ofr) {
775 auto attr = dyn_cast<Attribute>(ofr);
776 if (!attr)
777 return true;
778 return cast<IntegerAttr>(attr).getInt() != 0;
779 });
780 }
781 LogicalResult result = applyTilingToAll(
782 rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
783 transformResults, /*packedResults=*/getPackedTileSizes() != nullptr,
784 [&](TilingInterface tilingInterfaceOp)
785 -> FailureOr<scf::SCFTileAndFuseResult> {
786 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
787 tileAndFuseOptions);
788 });
791}
792
793LogicalResult transform::FuseOp::verify() {
794 bool hasPackedTiles = getPackedTileSizes() != nullptr;
795 if (!getMixedTileSizes().empty() && hasPackedTiles)
796 return emitOpError(
797 "tile_sizes and packed_tile_sizes are mutually exclusive");
798
799 auto iterspace_rank = getStaticTileSizes().size();
800 ArrayRef<int64_t> permutation = getStaticTileInterchange();
801 if (permutation.size() > iterspace_rank)
802 return emitOpError()
803 << "interchange length exceeds iteration space dimensions ("
804 << iterspace_rank << "), found " << getTileInterchange();
805 SmallVector<bool> seen(iterspace_rank, false);
806 for (int64_t v : permutation) {
807 if (!ShapedType::isDynamic(v)) {
808 if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))
809 return emitOpError() << "expects interchange values to be in range [0, "
810 << iterspace_rank << "), found: " << v;
811 if (seen[v])
812 return emitOpError() << "found duplicate interchange value: " << v;
813 seen[v] = true;
814 }
815 }
816
817 ArrayRef<int64_t> sizes = getStaticTileSizes();
818 size_t numExpectedLoops = getUseForall() || hasPackedTiles
819 ? 1
820 : sizes.size() - llvm::count(sizes, 0);
821 if (numExpectedLoops != getNumResults() - 1)
822 return emitOpError() << "expects " << numExpectedLoops << " loop results";
823
824 return success();
825}
826
827SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
828 return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
829}
830
831SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
832 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
833 getContext());
834}
835
// Declares the effects of the op: the target handle is consumed; the tile
// sizes, packed tile sizes and tile interchange handles are only read; all
// results are produced handles and the payload is modified.
// NOTE(review): embedded source line 837 (presumably the `effects` parameter
// declaration) is missing from this listing.
836void transform::FuseOp::getEffects(
838 consumesHandle(getTargetMutable(), effects);
839 onlyReadsHandle(getTileSizesMutable(), effects);
840 onlyReadsHandle(getPackedTileSizesMutable(), effects);
841 onlyReadsHandle(getTileInterchangeMutable(), effects);
842 producesHandle(getOperation()->getOpResults(), effects);
843 modifiesPayload(effects);
844}
845
846//===----------------------------------------------------------------------===//
847// FuseIntoContainingOp
848//===----------------------------------------------------------------------===//
849
// Builder that takes the producer and containing op handles as operands and
// gives the op two results, both of type AnyOpType.
// NOTE(review): embedded source line 851 (presumably the `result` parameter
// declaration) is missing from this listing.
850void transform::FuseIntoContainingOp::build(OpBuilder &builder,
852 Value producerOp,
853 Value containingOp) {
854 result.addOperands({producerOp, containingOp});
855 auto resultType = transform::AnyOpType::get(builder.getContext());
856 result.addTypes({resultType, resultType});
857}
858
859/// Add new operands to the forall op for users of the producerOp
860/// that are dominated by the containing scf.forall op.
// NOTE(review): embedded source lines 861 and 865 are missing from this
// listing; presumably they hold the leading signature line and the trailing
// parameter(s), including the `sizes` used below.
//
// Visible behavior:
//  - Returns nullptr when no user of the producer result lies outside of
//    `containingOp` while being dominated by it, or when `producerOp` is not
//    a linalg.generic.
//  - Otherwise creates a new scf.forall with one extra output (the producer's
//    output for `resultNumber`), moves the body of the old forall into it,
//    inserts a parallel_insert_slice of the first tiled value into the new
//    output, redirects the old forall results and the dominated users to the
//    new op, and returns the new op.
862 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
863 Operation *containingOp, TilingResult &tileAndFuseResult,
864 int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
866
867 // Collect the users outside of the containing op that it dominates
868 SetVector<Operation *> dominatedUsers;
869 DominanceInfo domInfo(containingOp);
870 for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
871 if (!containingOp->isAncestor(user) &&
872 (domInfo.dominates(containingOp, user))) {
873 dominatedUsers.insert(user);
874 }
875 }
876 if (dominatedUsers.empty())
877 return nullptr;
878
879 // Set up the insertion point for the new scf.forall op
880 auto forallOp = cast<scf::ForallOp>(containingOp);
881 OpBuilder::InsertionGuard g(rewriter);
882 rewriter.setInsertionPoint(forallOp);
883
884 // Get new output
885 Location loc = forallOp.getLoc();
886 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
887 if (!genericOp)
888 return nullptr;
889 SmallVector<Value> outputs = genericOp.getOutputs();
890 SmallVector<Value> newOuts(forallOp.getOutputs());
891 newOuts.push_back(outputs[resultNumber]);
892
893 // Create new scf.forall op
894 auto newforallOp = scf::ForallOp::create(
895 rewriter, loc, forallOp.getMixedLowerBound(),
896 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
897 forallOp.getMapping());
  // The body created with the new op is dropped in favor of the old op's
  // region.
898 rewriter.eraseBlock(newforallOp.getBody());
899 newforallOp.getRegion().takeBody(forallOp.getRegion());
900
901 // Add additional block argument for new value being returned
902 // and replaces all uses of the new output with corresponding bbArg
903 // inside the scf.forall to enable fusion into this new scf.forall.
904 newforallOp.getBody()->addArgument(newOuts.back().getType(),
905 newOuts.back().getLoc());
906 auto bbArgs = newforallOp.getBody()->getArguments();
907 rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
908 [&](OpOperand &use) {
909 Operation *op = use.getOwner();
910 return newforallOp->isProperAncestor(op);
911 });
912
913 // Fix terminator
  // The new parallel_insert_slice is created before the first yielding op of
  // the terminator.
914 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
915 SmallVector<Operation *> yieldingOps = llvm::map_to_vector<4>(
916 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
917 Operation *firstYieldOp = yieldingOps.front();
918 rewriter.setInsertionPoint(firstYieldOp);
919 Value src = tileAndFuseResult.tiledValues[0];
920 Value dst = newforallOp.getRegionIterArgs().back();
921 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
922 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
923 dst, offsets, sizes, strides);
924
  // Results of the old forall map one-to-one onto the leading results of the
  // new one; the extra trailing result replaces the producer result for the
  // dominated users only.
925 for (auto result : llvm::enumerate(forallOp.getResults())) {
926 rewriter.replaceAllUsesWith(result.value(),
927 newforallOp->getResult(result.index()));
928 }
929 rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
930 newforallOp->getResults().back(),
931 [&](OpOperand &use) {
932 Operation *user = use.getOwner();
933 return dominatedUsers.contains(user);
934 });
935 return newforallOp;
936}
937
938/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
939/// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
940/// outer loop. To determine the second condition, this function iterates
941/// using a worklist over the enclosing loops, trying to find 'src' in any of
942/// the parent loop's iter args.
943static bool sameOrEquivalentIterArg(Value src, Value dst) {
944 // Stack like vector containing possible iterArgs candidates. The first one
945 // is dst, and we will transverse the IR from there.
946 SmallVector<Value> destWorklist;
947 destWorklist.push_back(dst);
948
949 while (!destWorklist.empty()) {
950 Value currentDst = destWorklist.pop_back_val();
951
952 // We have found the same operand in some iter arg in the loop structure,
953 // so src and dst are equivalent.
954 if (src == currentDst)
955 return true;
956
957 // The operands are not equivalent, look for enclosing loops over
958 // currentDst.
959 auto bbArg = dyn_cast<BlockArgument>(currentDst);
960 if (!bbArg)
961 continue;
962
963 Block *parentBlock = bbArg.getOwner();
964 assert(parentBlock && "unlinked block argument");
965
966 Operation *parentOp = parentBlock->getParentOp();
967 assert(parentOp && "expected block argument with parent operation");
968
969 // Check if parent is loop-like. If it's not, do not add it to the worklist.
970 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
971 if (!parentLoop)
972 continue;
973
974 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
975 // No need to check for null as innerIterArg is tied to parentLoop.
976 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
977 Value loopBlockArgument =
978 parentLoop->getOperand(operand->getOperandNumber());
979 destWorklist.push_back(loopBlockArgument);
980 }
981 }
982
983 return false;
984}
985
986/// Find the first "extract" user of `producerOp` and tile it right before its
987/// use. The tiled op is fused under the `containingOp`.
988/// Return this fused op on success or nullptr if anything fails.
989/// If tiled op has uses that are dominated by `containingOp`, return
990/// a new `containingOp` with results of the fused op appended to
991/// results of the `containingOp` or nullptr if there are no dominated uses.
///
/// The result is a (tiled ops, new containing op) tuple. Every failure path
/// below attaches a note to `diag` and returns a value-initialized tuple.
992static std::tuple<SmallVector<Operation *>, Operation *>
994 Operation *producerOp, Operation *containingOp) {
995 LDBG() << "Try to fuse a direct extract use";
996 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
997 if (!tileableProducer) {
998 diag.attachNote(producerOp->getLoc())
999 << "producer is not a TileableInterface: " << *producerOp;
1000 return {};
1001 }
1002
1003 // Search the producer slices accessed within the containing operation.
1004 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
1005 // evolve into an interface.
 // Only tensor.extract_slice users strictly nested in `containingOp` qualify.
1006 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
1007 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1008 return sliceOp && containingOp->isProperAncestor(sliceOp);
1009 });
1010
1011 // Find a fusion opportunity.
1012 if (it == tileableProducer->getUsers().end()) {
1013 diag.attachNote(tileableProducer->getLoc())
1014 << "could not find fusion opportunity for: " << *tileableProducer;
1015 return {};
1016 }
1017 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
1018
1019 // Try to fuse the producer in-place.
 // The guard restores the rewriter's insertion point when this function
 // returns.
1020 OpBuilder::InsertionGuard guard(rewriter);
1021 rewriter.setInsertionPoint(sliceOpToTile);
1022
1023 // Clone the producer inside the consumer and try to update the producer init
1024 // operands using the loop bbArgs if applicable. More precisely, if the bbArg
1025 // of the container loop points to a value that it is used by the consumer op,
1026 // then, instead of using such value on the consumer, use the value coming
1027 // from the bbArg instead. This allows to reuse the output tensor (instead of
1028 // creating a new one) of the container when both producer and container write
1029 // to the same output.
1030 if (LoopLikeOpInterface containerLoop =
1031 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
1032 Operation *clone = rewriter.clone(*producerOp);
1033 rewriter.modifyOpInPlace(clone, [&]() {
1034 // Iterate over the outputs of the producer and over the loop bbArgs and
1035 // check if any bbArg points to the same value as the producer output. In
1036 // such case, make the producer output point to the bbArg directly.
1037 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
1038 if (!dpsInterface)
1039 return;
1040
1041 for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1042 Value producerOperand =
1043 clone->getOperand(initOperandPtr.getOperandNumber());
1044 for (BlockArgument containerIterArg :
1045 containerLoop.getRegionIterArgs()) {
1046 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1047 Value consumerOperand =
1048 containerLoop->getOperand(bbArg->getOperandNumber());
1049 // The producer has the same init as the loop bbArg, use it.
1050 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
1051 initOperandPtr.set(containerIterArg);
1052 }
1053 }
1054 }
1055 });
1056
 // From here on the clone, not the original producer, is the op being tiled.
1057 tileableProducer = dyn_cast<TilingInterface>(clone);
1058 }
1059
1060 // Tile the producer.
 // The slice source is cast to an OpResult to recover which producer result
 // is being sliced.
1061 int64_t resultNumber =
1062 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1063 LDBG() << "resultNumber: " << resultNumber;
1064
1065 SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
1066 SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
1067
1068 FailureOr<TilingResult> tileAndFuseResult =
1069 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1070 sizes);
1071
1072 if (failed(tileAndFuseResult)) {
1073 diag.attachNote(tileableProducer->getLoc())
1074 << "failed to tile producer op: " << *tileableProducer;
1075 return {};
1076 }
1077
1078#ifndef NDEBUG
1079 for (auto *tiledOp : tileAndFuseResult->tiledOps) {
1080 LDBG() << "tiledProducer: " << *tiledOp;
1081 }
1082#endif
1083
1084 // Replace the extract op.
 // The tiled value is rank-reduced, if needed, to the shape of the slice
 // result before it replaces the slice.
1085 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1086 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1087 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1088 if (failed(maybeRankReduced)) {
1089 diag.attachNote(producerOp->getLoc())
1090 << "shape types don't match (missing canonicalization?):\nTiledOp: "
1091 << tileAndFuseResult->tiledValues[0]
1092 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
1093 return {};
1094 }
1095 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1096
1097 // Add new outputs to containing op, if required
1098 Operation *newContainingOp = replaceForAllWithNewSignature(
1099 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
1100 resultNumber, offsets, sizes);
1101
1102 // Cleanup clone.
 // NOTE(review): the clone above is created when the *parent of the slice* is
 // loop-like, whereas this cleanup tests `containingOp`. If the two
 // conditions can differ, `tileableProducer` here still refers to the
 // original producer - confirm they always coincide.
1103 if (isa<LoopLikeOpInterface>(containingOp))
1104 rewriter.eraseOp(tileableProducer);
1105
1106 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1107}
1108
1109/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
1110/// it is exactly the `containingOp`, otherwise bail.
1111/// Then, find the first "extract" user of the tied block argument and tile it
1112/// right before its "extract" use. The tiled op is fused under the
1113/// `containingOp`.
1114/// Return this fused op on success or nullptr if anything fails.
///
/// Each failure path below attaches a note to `diag` and returns `{}`; on
/// success the tiled ops of the tiling result are returned.
1117 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
1118 Operation *containingOp) {
1119 LDBG() << "Try to fuse an extract use through block argument";
1120
1121 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1122 if (!tileableProducer) {
1123 diag.attachNote(producerOp->getLoc())
1124 << "producer is not a TileableInterface: " << *producerOp;
1125 return {};
1126 }
1127
1128 // Search the first use by a "scf::ForallOp" user.
 // The predicate assigns `forallOp` as a side effect, so after the search it
 // holds the owner of the matching use (or null if none matched).
1129 scf::ForallOp forallOp;
1130 auto itProducerUses =
1131 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
1132 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1133 return forallOp;
1134 });
1135 // If it's not from the containing op, return.
1136 if (!forallOp || forallOp != containingOp) {
1137 diag.attachNote(tileableProducer->getLoc())
1138 << "could not find a use by the containing op: " << *tileableProducer;
1139 return {};
1140 }
1141
1142 // Search the producer slices accessed within the containing
1143 // operation.
1144 // TODO: Generalize to more extract/insert/parallel_insert triples.
1145 // Maybe evolve into an interface.
1146 OpOperand *pUse = &(*itProducerUses);
1147 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
1148
1149 // Search the producer slices accessed within the containing operation.
1150 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
1151 // evolve into an interface.
1152 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
1153 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1154 return sliceOp && containingOp->isProperAncestor(sliceOp);
1155 });
1156
1157 // Find a fusion opportunity.
1158 if (itBBArgUsers == bbArg.getUsers().end()) {
1159 diag.attachNote(containingOp->getLoc())
1160 << "could not find fusion opportunity for bbArg: " << bbArg;
1161 return {};
1162 }
1163 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1164
1165 // Try to fuse the producer in-place.
1166 OpBuilder::InsertionGuard guard(rewriter);
1167 rewriter.setInsertionPoint(sliceOpToTile);
1168
1169 // Replace the use in the tileableProducer before tiling: clone, replace and
1170 // then tile.
1171 int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
1172 LDBG() << "resultNumber: " << resultNumber;
1173
1174 // Gather destination tensors.
1175 SmallVector<Value> destinationTensors;
1177 rewriter, tileableProducer->getLoc(), tileableProducer,
1178 destinationTensors))) {
1179 diag.attachNote(tileableProducer->getLoc())
1180 << "failed to get destination tensors for: " << *tileableProducer;
1181 return {};
1182 }
1183
 // Clone the producer with the destination of result `resultNumber` remapped
 // to the block argument, so the clone writes into the loop-carried value.
1184 IRMapping bvm;
1185 bvm.map(destinationTensors[resultNumber], bbArg);
1186 auto tileableProducerClone =
1187 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
 // The clone is only a vehicle for tiling; it is erased on every exit path
 // from this point on.
1188 llvm::scope_exit scopeGuard(
1189 [&]() { rewriter.eraseOp(tileableProducerClone); });
1190
1191 // Tile the producer.
1192 FailureOr<TilingResult> tileAndFuseResult =
1193 tileableProducerClone.generateResultTileValue(
1194 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1195 sliceOpToTile.getMixedSizes());
1196 if (failed(tileAndFuseResult)) {
1197 diag.attachNote(tileableProducer->getLoc())
1198 << "failed to tile producer op: " << *tileableProducer;
1199 return {};
1200 }
1201
1202 // Replace the extract op.
 // A rank-reduction failure is only checked by an assertion here, not
 // reported through `diag`.
1203 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1204 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1205 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1206 assert(succeeded(maybeRankReduced) && "unexpected shape");
1207 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1208
1209 // Replace the use in containingOp.
 // NOTE(review): the mapping above uses destinationTensors[resultNumber] but
 // the operand is re-wired to destinationTensors.front(); these differ when
 // resultNumber != 0 - confirm this is intended.
1210 rewriter.modifyOpInPlace(containingOp, [&]() {
1211 containingOp->setOperand(pUse->getOperandNumber(),
1212 destinationTensors.front());
1213 });
1214
1215 return tileAndFuseResult->tiledOps;
1216}
1217
1219 Operation *producerOp,
1220 Operation *containingOp) {
 // Fuses by plain cloning: a copy of `producerOp` is created right before the
 // first use found strictly inside `containingOp`, and that single use is
 // redirected to the copy. Returns the clone, or nullptr (with a note attached
 // to `diag`) when there is nothing to fuse or when `containingOp` itself
 // uses the producer.
1221 LDBG() << "Try to fuse an use by cloning";
1222
1223 // Gather all uses inside the containing op.
1225 for (OpResult result : producerOp->getOpResults()) {
1226 for (OpOperand &use : result.getUses()) {
1227 if (containingOp->isProperAncestor(use.getOwner())) {
1228 uses.push_back(&use);
1229 continue;
1230 }
1231 // Cannot clone and fuse if the use is by the containing op itself: fail
1232 // immediately.
1233 if (containingOp == use.getOwner()) {
1234 diag.attachNote(producerOp->getLoc())
1235 << "producer op use by containing op cannot be fused by cloning";
1236 return nullptr;
1237 }
1238 }
1239 }
1240
1241 // Check for a non-empty list of fusion opportunities.
1242 if (uses.empty()) {
1243 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1244 return nullptr;
1245 }
1246
1247 // Clone and fuse inside the containing op.
 // Only the first collected use is handled here.
1248 Operation *fusedOp = nullptr;
1249 OpOperand *use = uses.front();
1250 // Parallel insert slice is not a valid clone destination.
1251 // TODO: Generalize to other type of ops.
1252 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1253 "Parallel insert slice is not a valid clone destination");
1254 unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1255 LDBG() << "resultNumber: " << resultNumber;
1256
 // Insert the clone immediately before the user and point the use at the
 // clone's matching result.
1257 OpBuilder::InsertionGuard guard(rewriter);
1258 rewriter.setInsertionPoint(use->getOwner());
1259 fusedOp = rewriter.clone(*producerOp);
1260 rewriter.modifyOpInPlace(
1261 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1262
1263 return fusedOp;
1264}
1265
1266bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1267 // Allow repeated handles since we are fusing everything anyway.
1268 return true;
1269}
1270
1272transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1275 SmallVector<Operation *> fusedOps;
 // Overview: exactly one containing op is required. Producers are then fused
 // one use at a time, trying three strategies in order: tiling a direct
 // extract_slice use, tiling an extract_slice use reached through a block
 // argument of the containing op, and finally plain cloning.
1276 auto producerOps = state.getPayloadOps(getProducerOp());
1277 auto containingOps = state.getPayloadOps(getContainingOp());
1278 if (!llvm::hasSingleElement(containingOps)) {
1279 return emitDefiniteFailure()
1280 << "requires exactly one containing_op handle (got "
1281 << llvm::range_size(containingOps) << ")";
1282 }
1283 Operation *containingOp = *containingOps.begin();
1284
1285 // If nothing to fuse, propagate success.
1286 if (std::empty(producerOps)) {
1287 results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1288 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1290 }
1291
1292 // Helper function to find the next producer that should be fused. Take any
1293 // producer that has a use inside the containing op.
1294 SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1295 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1296 for (const auto &it : enumerate(remainingProducers)) {
1297 Operation *producerOp = it.value();
1298 // The containing op may be a user of producerOp: use isAncestor.
1299 int64_t numUsesInContainingOp =
1300 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1301 return containingOp->isAncestor(op);
1302 });
1303 // TODO: When resolving the TODO below (no duplicate ops), take an op
1304 // that has no use among the remaining producers. This is a topological
1305 // sorting.
 // A producer is removed from the set only when this is its last use in
 // the containing op; otherwise it is returned again on a later call.
1306 if (numUsesInContainingOp > 0) {
1307 if (numUsesInContainingOp == 1)
1308 remainingProducers.erase(remainingProducers.begin() + it.index());
1309 return producerOp;
1310 }
1311 }
1312 return failure();
1313 };
1314
1315 while (!remainingProducers.empty()) {
1316 auto nextProducer = getNextProducer();
1317 if (failed(nextProducer)) {
1318 auto diag = mlir::emitSilenceableFailure(getLoc())
1319 << "could not find next producer to fuse into container";
1320 diag.attachNote(containingOp->getLoc()) << "containing op";
1321 return diag;
1322 }
1323
1324 Operation *producerOp = *nextProducer;
1325
1326 // Default diagnostic, to be complemented with more failure information.
1328 diag << "could not fuse " << *producerOp << " into " << *containingOp;
1329
1330 // TODO: If there are multiple uses of the producer in the containing op,
1331 // we currently tile/clone the op multiple times (once per use). In some
1332 // cases, we can tile/clone once and reuse the value for each use.
1333 // Furthermore, producers should then be traversed according to a
1334 // topological sorting.
1335 auto [tiledOps, newContainingOp] =
1336 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1337 if (!tiledOps.empty()) {
1338 LDBG() << "\nFused a direct extract use\n" << *containingOp;
1339 fusedOps.append(tiledOps);
1340 if (newContainingOp) {
1341 // Update handles associated with the containing op so we don't need to
1342 // invalidate them. This is a hack to support better composability
1343 // between tiling and fusion while a proper mechanism is being
1344 // investigated.
1345 //
1346 // DO NOT replicate this elsewhere unless you understand what you are
1347 // doing.
1348 LogicalResult replacementStatus =
1349 rewriter.notifyPayloadOperationReplaced(containingOp,
1350 newContainingOp);
1351 (void)replacementStatus;
1352 assert(succeeded(replacementStatus) &&
1353 "unable to update transform state mapping");
1354 rewriter.eraseOp(containingOp);
1355 containingOp = newContainingOp;
1356 }
1357 continue;
1358 }
1359
 // Second strategy: fuse through a block argument of the containing op.
1360 SmallVector<Operation *> tiledContainingOpOperand =
1362 rewriter, diag, producerOp, containingOp);
1363 if (!tiledContainingOpOperand.empty()) {
1364 LDBG() << "\nFused an extract use through block argument\n"
1365 << *containingOp;
1366 fusedOps.append(tiledContainingOpOperand);
1367 continue;
1368 }
1369
 // Last resort: clone the producer next to its first use.
1370 Operation *cloned =
1371 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1372 if (cloned) {
1373 LDBG() << "\nFused an use by cloning\n" << *containingOp;
1374 fusedOps.push_back(cloned);
1375 continue;
1376 }
1378 }
1379
 // `containingOp` may have been replaced during fusion; the current one is
 // what gets published as the new containing op.
1380 results.set(cast<OpResult>(getFusedOp()), fusedOps);
1381 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1383}
1384
/// Declares the op's effects: the producer handle is consumed, the containing
/// op handle is only read, all results are produced, and the payload IR is
/// modified.
1385void transform::FuseIntoContainingOp::getEffects(
1387 consumesHandle(getProducerOpMutable(), effects);
1388 onlyReadsHandle(getContainingOpMutable(), effects);
1389 producesHandle(getOperation()->getOpResults(), effects);
1390 modifiesPayload(effects);
1391}
1392
1393//===----------------------------------------------------------------------===//
1394// GeneralizeOp
1395//===----------------------------------------------------------------------===//
1396
1398transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1399 LinalgOp target,
 // Rewrites `target` via generalizeNamedOp. A target that already is a
 // GenericOp is forwarded to `results` unchanged.
1402 // Exit early if no transformation is needed.
1403 if (isa<GenericOp>(target)) {
1404 results.push_back(target);
1406 }
1407 rewriter.setInsertionPoint(target);
1408 FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1409 if (succeeded(generic)) {
1410 results.push_back(generic->getOperation());
1412 }
 // Reached when generalization failed: report the default silenceable
 // failure for this target.
1413 return emitDefaultSilenceableFailure(target);
1414}
1415
1416//===----------------------------------------------------------------------===//
1417// SpecializeOp
1418//===----------------------------------------------------------------------===//
1419
1421transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1422 LinalgOp target,
 // Rewrites a GenericOp `target` via specializeGenericOp. Any target that is
 // not a GenericOp is forwarded to `results` unchanged.
1425 // Exit early if the operation is not a generic.
1426 if (!isa<GenericOp>(target)) {
1427 results.push_back(target);
1429 }
1430 rewriter.setInsertionPoint(target);
 // The op's emit-category setting is forwarded through the options.
1432 opts.emitCategoryOps = getEmitCategory();
1433 FailureOr<LinalgOp> named =
1434 specializeGenericOp(rewriter, cast<GenericOp>(target), opts);
1435 if (succeeded(named)) {
1436 results.push_back(named->getOperation());
1438 }
 // Reached when specialization failed: report the default silenceable
 // failure for this target.
1439 return emitDefaultSilenceableFailure(target);
1440}
1441
1442//===----------------------------------------------------------------------===//
1443// InterchangeOp
1444//===----------------------------------------------------------------------===//
1445
1447transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1448 GenericOp target,
 // Applies the iterator interchange attribute to `target` through
 // interchangeGenericOp. An empty interchange vector forwards `target` to
 // `results` unchanged.
1451 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1452 // Exit early if no transformation is needed.
1453 if (interchangeVector.empty()) {
1454 results.push_back(target);
1456 }
1457
 // The interchange must name every loop of the target: a length mismatch is
 // a silenceable error.
1458 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1459 if (interchangeVector.size() != numLoops) {
1460 return emitSilenceableError()
1461 << getIteratorInterchangeAttrName() << " has length ("
1462 << interchangeVector.size()
1463 << ") different from the number of loops in the target operation ("
1464 << numLoops << ")";
1465 }
1466 FailureOr<GenericOp> res = interchangeGenericOp(
1467 rewriter, target, SmallVector<unsigned>(interchangeVector));
 // Unlike the length check above, a failed rewrite is a definite failure.
1468 if (failed(res))
1469 return emitDefiniteFailure() << "failed to apply";
1470 results.push_back(res->getOperation());
1472}
1473
1474LogicalResult transform::InterchangeOp::verify() {
1475 ArrayRef<int64_t> permutation = getIteratorInterchange();
1476 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1477 if (!std::is_permutation(sequence.begin(), sequence.end(),
1478 permutation.begin(), permutation.end())) {
1479 return emitOpError()
1480 << "expects iterator_interchange to be a permutation, found "
1481 << getIteratorInterchange();
1482 }
1483 return success();
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// LinalgCopyToMemrefOp
1488//===----------------------------------------------------------------------===//
1489
/// Replaces a linalg.copy with a memref.copy. The target must be a
/// linalg.copy with pure buffer semantics whose input is a shaped type with
/// the same element type as the output; otherwise a silenceable error is
/// emitted with a note on the target op.
1490DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1491 transform::TransformRewriter &rewriter, Operation *targetOp,
1494
1495 // Check if the target can be converted.
1496 if (!isa<linalg::CopyOp>(targetOp)) {
1498 emitSilenceableError() << "only linalg.copy target ops are supported";
1499 diag.attachNote(targetOp->getLoc()) << "target op";
1500 return diag;
1501 }
1502
 // The isa<> check above guarantees this dyn_cast yields a non-null op.
1503 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1504 if (!copyOp.hasPureBufferSemantics()) {
1506 emitSilenceableError()
1507 << "cannot transform a linalg.copy on tensors into a memref.copy";
1508 diag.attachNote(targetOp->getLoc()) << "target op";
1509 return diag;
1510 }
1511
1512 SmallVector<Value> inputs = copyOp.getInputs();
1513 SmallVector<Value> outputs = copyOp.getOutputs();
1514 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1515 assert(outputs.size() == 1 && "expected memref copy op with one output");
1516 Value input = inputs.front();
1517 Value output = outputs.front();
1518
1519 // linalg.copy supports different element types on source/dest whereas
1520 // memref.copy does not, so we must check that the source and dest types can
1521 // be handled by memref.copy and otherwise reject the transformation.
1522 if (!isa<ShapedType>(input.getType())) {
1524 emitSilenceableError()
1525 << "cannot transform a linalg.copy which input has no shape";
1526 diag.attachNote(targetOp->getLoc()) << "target op";
1527 return diag;
1528 }
1529
1530 // linalg.copy destination must be a shaped type.
1531 assert(isa<ShapedType>(output.getType()));
1532
1533 if (cast<ShapedType>(input.getType()).getElementType() !=
1534 cast<ShapedType>(output.getType()).getElementType()) {
1536 emitSilenceableError()
1537 << "cannot transform a linalg.copy with different source and "
1538 "destination element types ";
1539 diag.attachNote(targetOp->getLoc()) << "target op";
1540 return diag;
1541 }
1542
1543 // Target can be converted, do it.
1544 auto memrefCopyOp =
1545 rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1546
 // The new memref.copy is the single result of the transform.
1547 results.push_back(memrefCopyOp);
1549}
1550
1551//===----------------------------------------------------------------------===//
1552// LowerPackOp
1553//===----------------------------------------------------------------------===//
1554
/// Lowers a linalg.pack through lowerPack and returns the resulting pad,
/// expand_shape and transpose ops, in that order. A failed lowering is
/// reported as a silenceable failure at the target's location.
1555DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1556 transform::TransformRewriter &rewriter, linalg::PackOp target,
1557 transform::ApplyToEachResultList &transformResults,
1559 rewriter.setInsertionPoint(target);
 // The op's flag is forwarded to lowerPack as-is.
1560 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1561 FailureOr<LowerPackResult> res =
1562 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1563 if (failed(res)) {
1564 return mlir::emitSilenceableFailure(target->getLoc())
1565 << "cannot lower to pad + expand + transpose";
1566 }
1567 transformResults.push_back(res->padOp);
1568 transformResults.push_back(res->expandShapeOp);
1569 transformResults.push_back(res->transposeOp);
1571}
1572
1573//===----------------------------------------------------------------------===//
1574// LowerUnPackOp
1575//===----------------------------------------------------------------------===//
1576
/// Lowers a linalg.unpack through lowerUnPack and returns the resulting
/// empty, transpose, collapse_shape, extract_slice and copy ops, in that
/// order. A failed lowering is a silenceable error with a note on the target
/// payload op.
1577DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1578 transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1579 transform::ApplyToEachResultList &transformResults,
1581 rewriter.setInsertionPoint(target);
 // The op's flag is forwarded to lowerUnPack as-is.
1582 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1583 FailureOr<LowerUnPackOpResult> res =
1584 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1585 if (failed(res)) {
1587 emitSilenceableError()
1588 << "cannot lower to transpose + collapse + extract";
1589 diag.attachNote(target->getLoc()) << "target payload op";
1590 return diag;
1591 }
1592 transformResults.push_back(res->emptyOp);
1593 transformResults.push_back(res->transposeOp);
1594 transformResults.push_back(res->collapseShapeOp);
1595 transformResults.push_back(res->extractSliceOp);
1596 transformResults.push_back(res->copyOp);
1598}
1599
1600//===---------------------------------------------------------------------===//
1601// MatchOp
1602//===---------------------------------------------------------------------===//
1603
1604void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1605 Value target, ArrayRef<StringRef> opNames) {
1606 result.addOperands(target);
1607 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1608 builder.getStrArrayAttr(opNames));
1609 result.addTypes(transform::AnyOpType::get(builder.getContext()));
1610}
1611
1612void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1613 TypeRange resultTypes, Value target,
1614 ArrayRef<StringRef> opNames) {
1615 result.addOperands(target);
1616 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1617 builder.getStrArrayAttr(opNames));
1618 result.addTypes(resultTypes);
1619}
1620
1622transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1625 llvm::StringSet<> strs;
 // Walks the single target payload op and collects every nested op that
 // satisfies all configured filters: op name, interface, attributes, result
 // type and operand types. `strs` holds the accepted op names, if any.
1626 if (getOps().has_value())
1627 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1628
1629 auto payloadOps = state.getPayloadOps(getTarget());
1630 if (!llvm::hasSingleElement(payloadOps)) {
1631 return emitDefiniteFailure("requires exactly one target handle");
1632 }
1633
 // Set by the walk callback when an operand-type filter with more than one
 // entry does not match an op's operand count; reported after the walk.
1635 bool incorrectNumOperandTypes = false;
1636 auto matchFun = [&](Operation *op) {
1637 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1638 return;
1639
1640 // Interfaces cannot be matched by name, just by ID.
1641 // So we specifically encode the interfaces we care about for this op.
1642 if (getInterface().has_value()) {
1643 auto iface = getInterface().value();
1644 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1645 !isa<LinalgOp>(op))
1646 return;
1647 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1648 !isa<TilingInterface>(op))
1649 return;
1650 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1651 !isa<LoopLikeOpInterface>(op))
1652 return;
1653 }
1654
1655 // Check if all specified attributes match.
 // The match op's own interface and ops attributes are skipped.
1656 if (getOpAttrs().has_value()) {
1657 DictionaryAttr opAttrs = getOpAttrs().value();
1658 for (NamedAttribute attr : opAttrs) {
1659 if (attr.getName() == getInterfaceAttrName() ||
1660 attr.getName() == getOpsAttrName())
1661 continue;
1662 if (!op->hasAttr(attr.getName()))
1663 return;
1664 if (op->getAttr(attr.getName()) != attr.getValue())
1665 return;
1666 }
1667 }
1668
 // The result-type filter only accepts ops with exactly one result of the
 // given type.
1669 if (getFilterResultType().has_value()) {
1670 Type t = getFilterResultType().value();
1671 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1672 return;
1673 }
1674
1675 if (getFilterOperandTypes().has_value()) {
1676 mlir::ArrayAttr types = getFilterOperandTypes().value();
1677 auto operandTypes = op->getOperandTypes();
1678
1679 if (types.size() == 1) {
1680 // All the operands must be equal to the specified type
 // NOTE(review): the dyn_cast result is used without a null check,
 // while the multi-type branch below uses cast<> - confirm the
 // attribute is always a TypeAttr here.
1681 auto typeattr =
1682 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1683 Type t = cast<::mlir::Type>(typeattr.getValue());
1684 if (!llvm::all_of(op->getOperandTypes(),
1685 [&](Type operandType) { return operandType == t; }))
1686 return;
1687 } else {
1688 // The operand types must match all the types in the list (in the same
1689 // order in which they are specified)
1690 if (types.size() != operandTypes.size()) {
1691 incorrectNumOperandTypes = true;
1692 return;
1693 }
1694
1695 for (auto [attr, operandType] :
1696 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1697 auto typeattr = cast<mlir::TypeAttr>(attr);
1698 Type type = cast<::mlir::Type>(typeattr.getValue());
1699
1700 if (type != operandType)
1701 return;
1702 }
1703 }
1704 }
1705
1706 // All constraints are satisfied.
1707 res.push_back(op);
1708 return;
1709 };
1710
1711 (*payloadOps.begin())->walk(matchFun);
1712 if (incorrectNumOperandTypes)
1713 return emitDefiniteFailure("If filter_operand_types contains more than a "
1714 "type, then it must contain as much types as "
1715 "the number of operands in the target ops");
1716 results.set(cast<OpResult>(getResult()), res);
1718}
1719
1720//===---------------------------------------------------------------------===//
1721// MultiTileSizesOp
1722//===---------------------------------------------------------------------===//
1723
1725 Type targetType, Type lowSizeType, Type,
1726 Type) {
 // Only the target type and the low-size type are printed, as a functional
 // type with one input and one result; the two unnamed type parameters are
 // ignored.
1727 printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1728}
1729
1730static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1731 Type &targetType, Type &lowSizeType,
1732 Type &highSizeType,
1733 Type &splitPointType) {
1734 FunctionType funcType;
1735 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1736 if (failed(parser.parseType<FunctionType>(funcType)))
1737 return failure();
1738
1739 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1740 parser.emitError(typeLoc) << "expects a trailing functional type with one "
1741 "argument and one result";
1742 }
1743 targetType = funcType.getInput(0);
1744 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1745
1746 return success();
1747}
1748
/// Computes the multi-size tiling specification for `target`.
///
/// When the low-size result is a transform parameter type, the sizes are
/// computed statically and returned as integer attributes (low tile size,
/// high tile size, and low tile size times low trip count); this requires a
/// statically shaped target. Otherwise the size computation is built as IR
/// before `target` and the defining ops of the sizes are returned.
1749DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1750 transform::TransformRewriter &rewriter, LinalgOp target,
1752 if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1753 if (target.hasDynamicShape()) {
1754 auto diag = emitSilenceableError()
1755 << "cannot compute parametric tile sizes for dynamically "
1756 "shaped payload op";
1757 diag.attachNote(target->getLoc()) << "payload op";
1758 return diag;
1759 }
1760
1761 FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1762 target, getDimension(), getTargetSize(), getDivisor());
1763 if (failed(spec)) {
1764 return emitSilenceableError()
1765 << "failed to compute multi-size tiling sizes";
1766 }
1767
 // Each value is wrapped in an integer attribute of the parameter's
 // underlying type.
1768 Builder builder(target.getContext());
1769 results.assign(llvm::map_range(
1770 ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1771 spec->lowTileSize * spec->lowTripCount}),
1772 [&builder, this](int64_t value) {
1773 return builder.getIntegerAttr(
1774 cast<ParamType>(getLowSize().getType()).getType(), value);
1775 }));
1777 }
1778
 // Handle-typed results: emit the size computation as IR before `target`.
1779 OpBuilder builder(target.getContext());
1780 builder.setInsertionPoint(target);
1781 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1782 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1783 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1784 builder, target, getDimension(), targetSize, divisor);
1785 if (failed(spec)) {
1786 return emitSilenceableError() << "could not generate tile size computation";
1787 }
1788
 // The split point is lowTileSize * lowTripCount, built as a composed
 // affine.apply over two symbols.
1789 AffineExpr s0 = builder.getAffineSymbolExpr(0);
1790 AffineExpr s1 = builder.getAffineSymbolExpr(1);
1791 Operation *splitPoint =
1792 affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1793 {spec->lowTileSize, spec->lowTripCount});
1794 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1795 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1796 assert(lowTileSize && highTileSize && splitPoint &&
1797 "tile sizes are not produced by operations");
1798 results.reserve(results.size() + 3);
1799 results.push_back(lowTileSize);
1800 results.push_back(highTileSize);
1801 results.push_back(splitPoint);
1803}
1804
// Declares the side effects of the op: the target handle is only read and all
// op results are produced. The payload is only read when the low-size result
// has a transform parameter type, and is modified otherwise.
// NOTE(review): listing line 1806 (the parameter list, presumably declaring
// `effects`) is missing from this listing — confirm upstream.
1805void transform::MultiTileSizesOp::getEffects(
1807 onlyReadsHandle(getTargetMutable(), effects);
1808 producesHandle(getOperation()->getOpResults(), effects);
1809 if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1810 onlyReadsPayload(effects);
1811 else
1812 modifiesPayload(effects);
1813}
1814
1815LogicalResult transform::MultiTileSizesOp::verify() {
1816 if (getLowSize().getType() != getHighSize().getType() ||
1817 getLowSize().getType() != getSplitPoint().getType()) {
1818 return emitOpError() << "expects all results type to be the same";
1819 }
1820 return success();
1821}
1822
1823//===---------------------------------------------------------------------===//
1824// PackOp
1825//===---------------------------------------------------------------------===//
1826
1827void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1828 Value target,
1829 ArrayRef<OpFoldResult> mixedPackedSizes) {
1830 SmallVector<int64_t> staticPackedSizes;
1831 SmallVector<Value> dynamicPackedSizes;
1832 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1833 staticPackedSizes);
1834 // Call the default builder which sets up the proper operands segment sizes
1835 // attributes for multiple variadic operands. In the absence of this, horrible
1836 // bugs ensue.
1837 Type linalgOpHType = transform::OperationType::get(
1838 builder.getContext(), GenericOp::getOperationName());
1839 build(builder, result,
1840 /*resultType=*/linalgOpHType,
1841 /*target=*/target,
1842 /*dynamic_sizes=*/dynamicPackedSizes,
1843 /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1844}
1845
1846SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1847 Builder b(getContext());
1848 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1849}
1850
// Applies packing to the single LinalgOp mapped to the target handle and sets
// the packed-op result handle to the packed linalg op. Emits a silenceable
// error when the target does not map to exactly one LinalgOp or when the
// number of packed sizes differs from the number of loops, and a definite
// failure when `pack` itself fails.
// NOTE(review): several listing lines of this definition are missing (1851,
// presumably the return type; 1854, presumably the `state` parameter; and the
// lines flagged below) — confirm against the upstream file.
1852transform::PackOp::apply(transform::TransformRewriter &rewriter,
1853 transform::TransformResults &transformResults,
1855 auto targetOps = state.getPayloadOps(getTarget());
1856 // If nothing to pack, propagate success.
1857 if (std::empty(targetOps)) {
1858 transformResults.set(cast<OpResult>(getPackedOp()),
// NOTE(review): listing lines 1859-1860 are missing here — presumably the
// second argument of `set` and the early success return; confirm upstream.
1861 }
1862 // Fail on multi-op handles.
1863 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1864 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1865 return emitSilenceableError()
1866 << "requires target to map to exactly 1 LinalgOp (got "
1867 << llvm::range_size(targetOps) << ")";
1868 }
1869 // Fail on mismatched number of pack sizes.
1870 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1871 return emitSilenceableError()
1872 << "requires number of packed sizes match the number of loops ("
1873 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1874 << ")";
1875 }
1876
1877 // Unpack handles to constants or actual SSA index values.
1878 SmallVector<OpFoldResult> packedSizes;
// NOTE(review): listing line 1879 (the callee that fills `packedSizes` from
// the arguments on the next line) is missing — confirm upstream, including
// whether its result is checked.
1880 state, *this, packedSizes, getMixedPackedSizes());
1881
// New IR is inserted right before the op being packed.
1882 rewriter.setInsertionPoint(linalgOp);
1883 FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1884 if (failed(maybeResult))
1885 return emitDefiniteFailure("data tiling failed");
1886
1887 transformResults.set(cast<OpResult>(getPackedOp()),
1888 {maybeResult->packedLinalgOp.getOperation()});
// NOTE(review): listing line 1889 is missing here — presumably the final
// success return; confirm upstream.
1890}
1891
// Declares the side effects of the op: the target handle is consumed, the
// packed-size handles are only read, and all op results are produced.
// NOTE(review): listing lines 1893 (the parameter list, presumably declaring
// `effects`) and 1897 (one more statement before the closing brace) are
// missing from this listing — confirm upstream.
1892void transform::PackOp::getEffects(
1894 transform::consumesHandle(getTargetMutable(), effects);
1895 transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1896 transform::producesHandle(getOperation()->getOpResults(), effects);
1898}
1899
1900//===---------------------------------------------------------------------===//
1901// PackGreedilyOp.
1902//===---------------------------------------------------------------------===//
1903
1904LogicalResult transform::PackGreedilyOp::verify() {
1905 if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1906 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1907 << " is not a valid permutation";
1908 }
1909 // TODO: relax to allow empty once we have another strategy than just matmul.
1910 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1911 for (auto [s, nmo] :
1912 llvm::zip_equal(getMixedMatmulPackedSizes(),
1913 getMatmulPaddedSizesNextMultipleOf())) {
1914 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1915 if (nmo != 0 &&
1916 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1917 return emitOpError() << "at most one of the packed_size and the "
1918 "padded_sizes_next_multiple_of can be nonzero "
1919 "for the matmul strategy";
1920 }
1921 }
1922 }
1923 return success();
1924}
1925
// Tries to greedily pack every LinalgOp mapped to the target handle as a
// matmul; payload ops that are not LinalgOps are skipped. For each LinalgOp,
// the packed op is recorded on success and the original op is recorded
// otherwise. The collected ops are set as the packed-op result.
// NOTE(review): listing lines 1926 (presumably the return type), 1929-1930
// (presumably the `state` parameter and the declaration of `results`) and
// 1954 (presumably the final return) are missing from this listing — confirm
// upstream.
1927PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1928 transform::TransformResults &transformResults,
1931 for (Operation *op : state.getPayloadOps(getTarget())) {
1932 auto linalgOp = dyn_cast<LinalgOp>(op);
1933 if (!linalgOp)
1934 continue;
1935 // linalgOp will be replaced and the insertion point may be invalidated if
1936 // we set it before -> set it after.
1937 rewriter.setInsertionPointAfter(linalgOp);
1938 // Failing to pack greedily is perfectly fine.
1939 // In the future we will want to order packings according to some metric.
1940 FailureOr<PackResult> packResult = packMatmulGreedily(
1941 /*rewriter=*/rewriter,
1942 /*linalgOp=*/linalgOp,
1943 /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1944 /*mnkPaddedSizesNextMultipleOf=*/
1945 getMatmulPaddedSizesNextMultipleOf(),
1946 /*mnkOrder=*/getMatmulInnerDimsOrder());
1947 if (succeeded(packResult)) {
1948 results.push_back(packResult->packedLinalgOp);
1949 continue;
1950 }
// Packing failed: forward the unmodified op instead.
1951 results.push_back(linalgOp);
1952 }
1953 transformResults.set(cast<OpResult>(getPackedOp()), results);
1955}
1956
1957SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1958 Builder b(getContext());
1959 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1960 b);
1961}
1962
// Declares the side effects of the op: the target handle is consumed, the
// matmul packed-size handles are only read, and all op results are produced.
// NOTE(review): listing lines 1964 (the parameter list, presumably declaring
// `effects`) and 1968 (one more statement before the closing brace) are
// missing from this listing — confirm upstream.
1963void transform::PackGreedilyOp::getEffects(
1965 transform::consumesHandle(getTargetMutable(), effects);
1966 transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1967 transform::producesHandle(getOperation()->getOpResults(), effects);
1969}
1970
1971//===---------------------------------------------------------------------===//
1972// PackTransposeOp
1973//===---------------------------------------------------------------------===//
1974
1975LogicalResult transform::PackTransposeOp::verify() {
1976 if (!isPermutationVector(getInnerPerm())) {
1977 return emitOpError() << getInnerPermAttrName()
1978 << " is not a valid permutation";
1979 }
1980 if (!isPermutationVector(getOuterPerm())) {
1981 return emitOpError() << getOuterPermAttrName()
1982 << " is not a valid permutation";
1983 }
1984 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1985 return emitOpError() << " at least one of " << getInnerPermAttrName()
1986 << " or " << getOuterPermAttrName()
1987 << " must be specified";
1988 }
1989 return success();
1990}
1991
namespace {
/// Selects whether a permutation refers to the outer or the inner dimensions.
enum class OuterOrInnerPerm {
  Outer = 0,
  Inner = 1,
};
} // namespace
1995
1996/// Return true if `permutation` is a valid permutation of the
1997/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1998/// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op.
1999/// This is the case when the `permutation` rank matches the rank expected by
2000/// `op` and `permutation` is itself a permutation vector.
2001/// Return true if either `op` or `permutation` are empty to allow a simpler
2002/// polymorphic implementation.
2003template <typename RelayoutOpTy>
2004static bool isValidPackingPermutation(
2005 RelayoutOpTy op, ArrayRef<int64_t> permutation,
2006 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
2007 static_assert(
2008 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
2009 "applies to only pack or unpack operations");
2010 if (!op || permutation.empty())
2011 return true;
2012 size_t innerRank = op.getInnerDimsPos().size();
2013 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
2014 return permutation.size() == innerRank && isPermutationVector(permutation);
2015 // op.getOuterDimsPerm() may be empty, in which case it is identity.
2016 // Don't rely on it.
2017 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
2018 return permutation.size() == op.getSourceRank() &&
2019 isPermutationVector(permutation);
2020 }
2021 return permutation.size() == op.getDestRank() &&
2022 isPermutationVector(permutation);
2023}
2024
// Transposes a pack (and, when present, the matching unpack) op together with
// the LinalgOp it feeds or is fed by. The op takes two handles: one mapped to
// a single pack or unpack op and one mapped to a single LinalgOp. On success
// it sets three results: the transposed pack op, the transposed LinalgOp and
// the transposed unpack op (empty when no unpack op is involved).
// NOTE(review): listing lines 2025 (presumably the return type), 2028
// (presumably the `state` parameter), 2036 and 2121 (presumably success
// returns) are missing from this listing — confirm upstream.
2026transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
2027 transform::TransformResults &transformResults,
2029 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
2030 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
2031 // Step 1. If nothing to pack, propagate success.
2032 if (std::empty(packOrUnpackOps)) {
2033 transformResults.set(cast<OpResult>(getPackedOp()), {});
2034 transformResults.set(cast<OpResult>(getPackOp()), {});
2035 transformResults.set(cast<OpResult>(getUnPackOp()), {});
2037 }
2038
2039 // Step 2. Bunch of runtime sanity check and error messages.
2040 // Step 2.1. Fail on multi-op handles.
2041 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2042 !llvm::hasSingleElement(linalgOps)) {
2043 return emitSilenceableError()
2044 << "requires target to map to exactly 1 "
2045 "packing op and 1 packed op ("
2046 << "got " << llvm::range_size(packOrUnpackOps) << " and "
2047 << llvm::range_size(linalgOps) << ")";
2048 }
2049
2050 // Step 2.2. Fail on wrong type.
// At most one of `packOp` / `unPackOp` is non-null at this point.
2051 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2052 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2053 if ((!packOp && !unPackOp)) {
2054 return emitSilenceableError() << "requires target to map to a "
2055 "linalg.pack or linalg.unpack";
2056 }
2057 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2058 if (!linalgOpTarget)
2059 return emitSilenceableError() << "requires a LinalgOp target";
2060
2061 // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
// For a pack op, the LinalgOp is its single user; for an unpack op, it is the
// op defining the unpack source. Either way it must be the LinalgOp target.
2062 LinalgOp linalgOp;
2063 if (packOp && packOp.getResult().hasOneUse())
2064 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2065 else if (unPackOp)
2066 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2067 if (linalgOp != linalgOpTarget) {
2068 auto errorMsg =
2069 packOp ? StringLiteral{"not a single use by the LinalgOp target"}
2070 : StringLiteral{"not produced by the LinalgOp target"};
2071 return emitSilenceableError() << errorMsg;
2072 }
2073
2074 // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
2075 // PackOp.
// The pack op is looked up as the producer of the init operand whose index
// is the result number of the unpacked LinalgOp result.
2076 if (unPackOp) {
2077 assert(!packOp && "packOp must be null on entry when unPackOp is not null");
2078 OpOperand *packUse = linalgOp.getDpsInitOperand(
2079 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2080 packOp = packUse->get().getDefiningOp<linalg::PackOp>();
2081 if (!packOp || !packOp.getResult().hasOneUse())
2082 return emitSilenceableError() << "could not find matching pack op";
2083 }
2084
2085 // Step 2.5. Fail if any permutation does not validate.
// Each permutation is checked against both the pack op and the unpack op; a
// null op is accepted by the helper.
2086 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2087 ArrayRef<int64_t> perm =
2088 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2089 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2090 ? StringLiteral{"invalid outer_perm"}
2091 : StringLiteral{"invalid inner_perm"};
2092 if (!isValidPackingPermutation(packOp, perm, permType) ||
2093 !isValidPackingPermutation(unPackOp, perm, permType)) {
2094 Operation *packOrUnpackOp =
2095 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2096 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
2097 }
2098 }
2099
2100 // From here on, packOp and linalgOp are always present, unPackOp may or may
2101 // not be present.
2102 assert(packOp && linalgOp && "unexpected null op");
2103
2104 // Step 3. Actually transpose the ops.
2105 FailureOr<PackTransposeResult> res = packTranspose(
2106 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2107 // Preconditions have been checked, it is an error to fail here.
2108 assert(succeeded(res) && "unexpected packTranspose failure");
2109
2110 // Step 4. Return results.
2111 transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2112 transformResults.set(cast<OpResult>(getPackedOp()),
2113 {res->transposedLinalgOp});
2114 if (unPackOp) {
2115 transformResults.set(cast<OpResult>(getUnPackOp()),
2116 {res->transposedUnPackOp});
2117 } else {
2118 transformResults.set(cast<OpResult>(getUnPackOp()), {});
2119 }
2120
2122}
2123
2124//===---------------------------------------------------------------------===//
2125// PadOp
2126//===---------------------------------------------------------------------===//
2127
// Builds a pad transform op from purely static `padToMultipleOf` values. Both
// results are typed as "any op" handles, the padding values attribute is left
// null, and no dynamic pad-to-multiple-of operands are passed.
2128void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2129 ArrayRef<int64_t> paddingDimensions,
2130 ArrayRef<int64_t> padToMultipleOf,
2131 ArrayRef<int64_t> nofoldFlags,
2132 ArrayRef<Attribute> transposePaddings,
2133 StringRef copyBackOp,
2134 bool usePrescribedTensorShapes) {
2135 auto resultType = transform::AnyOpType::get(b.getContext());
2136 return build(/*odsBuilder=*/b,
2137 /*result=*/result,
2138 /*types=*/TypeRange{resultType, resultType},
2139 /*target=*/target,
2140 /*padding_values=*/ArrayAttr(), // let inference handle this
2141 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2142 /*pad_to_multiple_of=*/ValueRange{},
2143 /*padToMultipleOf=*/
2144 (padToMultipleOf.empty()
// NOTE(review): listing line 2145 (the value used when `padToMultipleOf` is
// empty) is missing from this listing — confirm upstream.
2146 : b.getDenseI64ArrayAttr(padToMultipleOf)),
2147 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2148 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2149 /*copy_back_op=*/b.getStringAttr(copyBackOp),
2150 /*use_prescribed_tensor_shapes=*/
// The flag is encoded as a unit attribute that is present only when set.
2151 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
2152}
2153
2154void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2155 ArrayRef<int64_t> paddingDimensions,
2156 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
2157 ArrayRef<int64_t> nofoldFlags,
2158 ArrayRef<Attribute> transposePaddings,
2159 StringRef copyBackOp,
2160 bool usePrescribedTensorShapes) {
2161 auto resultType = transform::AnyOpType::get(b.getContext());
2162 SmallVector<int64_t> staticPadToMultipleOf;
2163 SmallVector<Value> dynamicPadToMultipleOf;
2164 dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
2165 staticPadToMultipleOf);
2166 return build(/*odsBuilder=*/b,
2167 /*result=*/result,
2168 /*types=*/TypeRange{resultType, resultType},
2169 /*target=*/target,
2170 /*padding_values=*/ArrayAttr(), // let inference handle this
2171 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2172 /*pad_to_multiple_of=*/dynamicPadToMultipleOf,
2173 /*padToMultipleOf=*/staticPadToMultipleOf,
2174 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2175 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2176 /*copy_back_op=*/copyBackOp,
2177 /*use_prescribed_tensor_shapes=*/usePrescribedTensorShapes);
2178}
2179
// Declares the side effects of the op: the target handle is consumed, the
// pad-to-multiple-of handles are only read, all op results are produced, and
// the payload is modified.
// NOTE(review): listing line 2181 (the parameter list, presumably declaring
// `effects`) is missing from this listing — confirm upstream.
2180void PadOp::getEffects(
2182 consumesHandle(getTargetMutable(), effects);
2183 onlyReadsHandle(getPadToMultipleOfMutable(), effects);
2184 producesHandle(getOperation()->getOpResults(), effects);
2185 modifiesPayload(effects);
2186}
2187
2188SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2189 Builder b(getContext());
2190 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
2191}
2192
// Pads every LinalgOp mapped to the target handle. For each payload op this
// converts the op's attributes into `LinalgPaddingOptions`, calls
// `rewriteAsPaddedOp`, and replaces the payload op with the returned
// replacement values. Three results are set at the end: the padded ops, the
// created `tensor::PadOp`s, and the ops defining the replacement values (only
// collected when a copy-back op is requested).
// A payload op that is not a LinalgOp yields a silenceable error.
// NOTE(review): listing lines 2226, 2238, 2248 (presumably failure returns
// following the emitted errors) and 2340 (presumably the final success
// return) are missing from this listing — confirm upstream.
2193DiagnosedSilenceableFailure
2194transform::PadOp::apply(transform::TransformRewriter &rewriter,
2195 transform::TransformResults &results,
2196 transform::TransformState &state) {
2197 auto transformOp = cast<TransformOpInterface>(getOperation());
2198 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
2199
2200 for (Operation *target : state.getPayloadOps(getTarget())) {
2201 auto linalgTarget = dyn_cast<LinalgOp>(target);
2202 if (!linalgTarget) {
2203 auto diag = emitSilenceableError() << "expected LinalgOp target";
2204 diag.attachNote(target->getLoc()) << "target op";
2205 return diag;
2206 }
2207
2208 // Convert the integer packing flags to booleans.
2209 SmallVector<bool> nofoldFlags;
2210 for (int64_t packPadding :
2211 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
2212 nofoldFlags.push_back(static_cast<bool>(packPadding));
2213
2214 // Convert the padding values to attributes.
// Padding values are paired positionally with the operand types of the
// payload op. `llvm::zip` stops at the shorter of the two ranges.
2215 SmallVector<Attribute> paddingValues;
2216 for (auto const &[untypedAttr, elementOrTensorType] :
2217 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2218
// Poison padding values are forwarded unchanged.
2219 if (matchPattern(untypedAttr, ub::m_Poison())) {
2220 paddingValues.push_back(untypedAttr);
2221 continue;
2222 }
2223 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2224 if (!attr) {
2225 emitOpError("expects padding values to be typed attributes or poison");
// NOTE(review): listing line 2226 is missing here — presumably a failure
// return; confirm upstream.
2227 }
2228 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2229 // Try to parse string attributes to obtain an attribute of element type.
2230 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2231 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2232 stringAttr, getContext(), elementType,
2233 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2234 if (!parsedAttr || parsedAttr.getType() != elementType) {
2235 auto diag = this->emitOpError("expects a padding that parses to ")
2236 << elementType << ", got " << untypedAttr;
2237 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
// NOTE(review): listing line 2238 is missing here — presumably a failure
// return; confirm upstream.
2239 }
2240 paddingValues.push_back(parsedAttr);
2241 continue;
2242 }
2243 // Otherwise, add the attribute directly.
2244 if (attr.getType() != elementType) {
2245 auto diag = this->emitOpError("expects a padding value of type ")
2246 << elementType << ", got " << attr;
2247 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
// NOTE(review): listing line 2248 is missing here — presumably a failure
// return; confirm upstream.
2249 }
2250 paddingValues.push_back(attr);
2251 }
2252
2253 // Extract the transpose vectors.
// NOTE(review): `transposePaddings` is filled here but not read again in the
// lines visible in this listing.
2254 SmallVector<SmallVector<int64_t>> transposePaddings;
2255 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2256 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2257 cast<ArrayAttr>(transposeVector)));
2258
2259 LinalgOp paddedOp;
2260 LinalgPaddingOptions options;
2261 options.paddingDimensions =
2262 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2263
// Resolve the mixed static/handle multiples to integers; when none are
// given, every padding dimension defaults to a multiple of 1.
2264 SmallVector<int64_t> padToMultipleOf;
2265 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
2266 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2267 if (!status.succeeded())
2268 return status;
2269 if (padToMultipleOf.empty())
2270 padToMultipleOf =
2271 SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2272
2273 options.padToMultipleOf = padToMultipleOf;
2274 options.paddingValues = paddingValues;
2275 options.nofoldFlags = nofoldFlags;
// Map the copy_back_op string onto the corresponding option; any other
// string is treated as unreachable.
2276 if (getCopyBackOp() ==
2277 bufferization::MaterializeInDestinationOp::getOperationName()) {
2278 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2279 BufferizationMaterializeInDestination;
2280 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2281 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2282 } else if (getCopyBackOp() == kCopyOpNone) {
2283 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2284 } else {
2285 llvm_unreachable("unsupported copy_back op");
2286 }
2287 // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
// `irChanged` records that size-computation IR may have been created, in
// which case a later padding failure is reported as a definite failure.
2288 bool irChanged = false;
2289 if (getUsePrescribedTensorShapes() &&
2290 linalgTarget.hasPureTensorSemantics()) {
2291 OpBuilder::InsertionGuard g(rewriter);
2292 rewriter.setInsertionPoint(linalgTarget);
2293 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2294 for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2295 if (ShapedType::isStatic(dim))
2296 continue;
2297 options.setSizeToPadTo(operand.getOperandNumber(), i,
2298 tensor::getMixedSize(rewriter,
2299 operand.get().getLoc(),
2300 operand.get(), i));
2301 irChanged = true;
2302 }
2303 }
2304 }
2305
2306 SmallVector<Value> replacements;
2307 SmallVector<tensor::PadOp> newPadOps;
2308 if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2309 replacements, newPadOps))) {
2310 if (irChanged) {
2311 auto diag = emitDefiniteFailure() << "failed to pad op";
2312 diag.attachNote(target->getLoc()) << "target op";
2313 return diag;
2314 }
2315 auto diag = emitSilenceableError() << "failed to pad op";
2316 diag.attachNote(target->getLoc()) << "target op";
2317 return diag;
2318 }
2319
2320 // We need to perform our own replacement here because this API is still
2321 // used in patterns that "pad and hoist", for which the replacement values
2322 // need to be different.
2323 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2324 // that we have more composable abstractions.
2325 rewriter.replaceOp(linalgTarget, replacements);
2326 paddedOps.push_back(paddedOp);
2327 padOps.append(newPadOps.begin(), newPadOps.end());
// Collect the ops defining the replacement values, without duplicates.
2328 if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2329 for (Value v : replacements) {
2330 Operation *copyBackOp = v.getDefiningOp();
2331 if (!llvm::is_contained(copyBackOps, copyBackOp))
2332 copyBackOps.push_back(copyBackOp);
2333 }
2334 }
2335 }
2336
2337 results.set(cast<OpResult>(getPadded()), paddedOps);
2338 results.set(cast<OpResult>(getPad()), padOps);
2339 results.set(cast<OpResult>(getCopy()), copyBackOps);
2341}
2342
2343LogicalResult transform::PadOp::verify() {
2344 SmallVector<int64_t> nofoldFlags =
2345 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2346 if (any_of(nofoldFlags, [](int64_t packPadding) {
2347 return packPadding != 0 && packPadding != 1;
2348 })) {
2349 return emitOpError()
2350 << "expects nofold_flags to contain booleans (0/1), found "
2351 << getNofoldFlags();
2352 }
2353
2354 SmallVector<int64_t> paddingDimensions =
2355 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2356 if (any_of(paddingDimensions,
2357 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2358 return emitOpError() << "expects padding_dimensions to contain positive "
2359 "integers, found "
2360 << getPaddingDimensions();
2361 }
2362 if (!getMixedPadToMultipleOf().empty()) {
2363 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2364 return emitOpError() << "expects as many multiples as padding_dimensions";
2365 }
2366 }
2367 ArrayAttr transposes = getTransposePaddings();
2368 for (Attribute attr : transposes) {
2369 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2370 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2371 if (!std::is_permutation(sequence.begin(), sequence.end(),
2372 transpose.begin(), transpose.end())) {
2373 return emitOpError()
2374 << "expects transpose_paddings to be a permutation, found "
2375 << attr;
2376 }
2377 }
2378 if (getCopyBackOp() !=
2379 bufferization::MaterializeInDestinationOp::getOperationName() &&
2380 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2381 getCopyBackOp() != kCopyOpNone)
2382 return emitOpError() << "invalid copy_back_op";
2383 return success();
2384}
2385
2386//===---------------------------------------------------------------------===//
2387// PadTilingInterfaceOp
2388//===---------------------------------------------------------------------===//
2389
2390void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2391 OperationState &result,
2392 Value target,
2393 ArrayRef<int64_t> paddingSizes,
2394 bool padToMultipleOf) {
2395 auto resultType = transform::AnyOpType::get(b.getContext());
2396 return build(/*odsBuilder=*/b,
2397 /*result=*/result,
2398 /*types=*/TypeRange{resultType, resultType},
2399 /*target=*/target,
2400 /*padding_values=*/ArrayAttr(), // let inference handle this
2401 /*padding_sizes=*/ValueRange{},
2402 /*paddingSizes=*/
2403 (paddingSizes.empty() ? DenseI64ArrayAttr()
2404 : b.getDenseI64ArrayAttr(paddingSizes)),
2405 /*pad_to_multiple_of=*/
2406 padToMultipleOf ? b.getUnitAttr() : nullptr);
2407}
2408
2409void transform::PadTilingInterfaceOp::build(
2410 OpBuilder &b, OperationState &result, Value target,
2411 ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2412 auto resultType = transform::AnyOpType::get(b.getContext());
2413 SmallVector<int64_t> staticPaddingSizes;
2414 SmallVector<Value> dynamicPaddingSizes;
2415 dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2416 staticPaddingSizes);
2417 return build(/*odsBuilder=*/b,
2418 /*result=*/result,
2419 /*types=*/TypeRange{resultType, resultType},
2420 /*target=*/target,
2421 /*padding_values=*/ArrayAttr(), // let inference handle this
2422 /*padding_sizes=*/dynamicPaddingSizes,
2423 /*paddingSizes=*/staticPaddingSizes,
2424 /*usePrescribedTensorShapes=*/padToMultipleOf);
2425}
2426
2427void transform::PadTilingInterfaceOp::getEffects(
2428 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2429 consumesHandle(getTargetMutable(), effects);
2430 onlyReadsHandle(getPaddingSizesMutable(), effects);
2431 producesHandle(getOperation()->getOpResults(), effects);
2432 modifiesPayload(effects);
2433}
2434
2435SmallVector<OpFoldResult>
2436transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2437 Builder b(getContext());
2438 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2439}
2440
// Pads every payload op mapped to the target handle through the
// TilingInterface-based `rewriteAsPaddedOp`. Each payload op must implement
// both TilingInterface and IndexingMapOpInterface, otherwise a silenceable
// error is emitted. On success the payload op is replaced with the sliced
// results, and two results are set at the end: the padded ops and the padded
// operands.
// NOTE(review): listing lines 2478, 2489, 2499 (presumably failure returns
// following the emitted errors) and 2529 (presumably the final success
// return) are missing from this listing — confirm upstream.
2441DiagnosedSilenceableFailure
2442transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2443 transform::TransformResults &results,
2444 transform::TransformState &state) {
2445 SmallVector<Operation *> paddedOps, padOps;
2446
2447 for (Operation *target : state.getPayloadOps(getTarget())) {
2448 auto targetOp = dyn_cast<TilingInterface>(target);
2449 if (!targetOp) {
2450 auto diag = emitSilenceableError() << "expected TilingInterface target";
2451 diag.attachNote(target->getLoc()) << "target op";
2452 return diag;
2453 }
2454
2455 // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2456 // loopsToOperand map / C++ APIs to compute the effect of padding on
2457 // operands.
2458 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2459 auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2460 "supported atm";
2461 diag.attachNote(target->getLoc()) << "target op";
2462 return diag;
2463 }
2464
2465 // Convert the padding values to attributes.
// Padding values are paired positionally with the operand types of the
// payload op. `llvm::zip` stops at the shorter of the two ranges.
2466 SmallVector<Attribute> paddingValues;
2467 for (auto const &[untypedAttr, elementOrTensorType] :
2468 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2469 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2470 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2471
// Poison padding values are forwarded unchanged.
2472 if (matchPattern(untypedAttr, ub::m_Poison())) {
2473 paddingValues.push_back(untypedAttr);
2474 continue;
2475 }
2476 if (!attr) {
2477 emitOpError("expects padding values to be typed attributes or poison");
// NOTE(review): listing line 2478 is missing here — presumably a failure
// return; confirm upstream.
2479 }
2480 // Try to parse string attributes to obtain an attribute of element type.
2481 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2482 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2483 stringAttr, getContext(), elementType,
2484 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2485 if (!parsedAttr || parsedAttr.getType() != elementType) {
2486 auto diag = this->emitOpError("expects a padding that parses to ")
2487 << elementType << ", got " << attr;
2488 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
// NOTE(review): listing line 2489 is missing here — presumably a failure
// return; confirm upstream.
2490 }
2491 paddingValues.push_back(parsedAttr);
2492 continue;
2493 }
2494 // Otherwise, add the attribute directly.
2495 if (attr.getType() != elementType) {
2496 auto diag = this->emitOpError("expects a padding value of type ")
2497 << elementType << ", got " << attr;
2498 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
// NOTE(review): listing line 2499 is missing here — presumably a failure
// return; confirm upstream.
2500 }
2501 paddingValues.push_back(attr);
2502 }
2503
2504 // Set options.
2505 PadTilingInterfaceOptions options;
2506 options.setPaddingValues(paddingValues)
2507 .setPaddingSizes(getMixedPaddingSizes())
2508 .setPadToMultipleOf(getPadToMultipleOf());
2509
// New IR is inserted after the payload op; the guard restores the insertion
// point when it goes out of scope.
2510 OpBuilder::InsertionGuard g(rewriter);
2511 rewriter.setInsertionPointAfter(targetOp);
2512 auto maybePadOps = rewriteAsPaddedOp(
2513 rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
2514 if (failed(maybePadOps)) {
2515 auto diag = emitSilenceableError() << "failed to pad op";
2516 diag.attachNote(target->getLoc()) << "target op";
2517 return diag;
2518 }
2519 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2520
2521 // Set transform results.
2522 paddedOps.push_back(paddedOp);
2523 padOps.append(paddedOperands.begin(), paddedOperands.end());
2524 rewriter.replaceOp(targetOp.getOperation(), slicedResults);
2525 }
2526
2527 results.set(cast<OpResult>(getPadded()), paddedOps);
2528 results.set(cast<OpResult>(getPad()), padOps);
2530}
2531
2532LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2533
2534//===---------------------------------------------------------------------===//
2535// HoistPadOp
2536//===---------------------------------------------------------------------===//
2537
// Builds the packing loop nest for a single `tensor::PadOp` with respect to a
// single `scf::ForOp`. Both handles must map to exactly one payload op of the
// expected kind; every error in this function is a definite failure. The
// packing-loop result is set to the hoisted pad op when no loop induction
// variables were cloned, and to the loop owning the first cloned induction
// variable otherwise.
// NOTE(review): listing lines 2565 and 2571 (presumably success returns) are
// missing from this listing — confirm upstream.
2538DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2539 transform::TransformRewriter &rewriter,
2540 transform::TransformResults &transformResults,
2541 transform::TransformState &state) {
2542 auto targetOps = state.getPayloadOps(getTarget());
2543 auto loopOps = state.getPayloadOps(getLoop());
2544 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2545 return emitDefiniteFailure()
2546 << "requires exactly one target and one loop handle (got "
2547 << llvm::range_size(targetOps) << " and "
2548 << llvm::range_size(loopOps) << ")";
2549 }
2550
// The casts yield null for a null payload op or one of the wrong kind.
2551 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2552 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2553 if (!padOp || !loopOp)
2554 return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2555
2556 FailureOr<linalg::detail::PackingResult> result =
2557 linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2558 getTranspose());
2559 if (failed(result))
2560 return emitDefiniteFailure() << "could not build packing loop nest";
2561
// Without cloned induction variables, the hoisted pad op itself is returned.
2562 if (result->clonedLoopIvs.empty()) {
2563 transformResults.set(cast<OpResult>(getPackingLoop()),
2564 {result->hoistedPadOp.getOperation()});
2566 }
// Otherwise return the loop that owns the first cloned induction variable.
2567 auto outerPackedLoop =
2568 scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2569 transformResults.set(cast<OpResult>(getPackingLoop()),
2570 {outerPackedLoop.getOperation()});
2572}
2573
2574LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2575 ArrayRef<int64_t> transpose = getTranspose();
2576 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2577 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2578 transpose.end())) {
2579 return emitOpError() << "expects transpose to be a permutation, found "
2580 << getTranspose();
2581 }
2582 return success();
2583}
2584
2585void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2586 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2587 transform::onlyReadsHandle(getTargetMutable(), effects);
2588 transform::onlyReadsHandle(getLoopMutable(), effects);
2589 transform::producesHandle(getOperation()->getOpResults(), effects);
2591}
2592
2593DiagnosedSilenceableFailure
2594transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2595 tensor::PadOp target,
2596 transform::ApplyToEachResultList &results,
2597 transform::TransformState &state) {
2598 tensor::PadOp hoistedPadOp;
2599 SmallVector<TransposeOp> transposeOps;
2600 FailureOr<Value> result =
2601 hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2602 hoistedPadOp, transposeOps);
2603 if (succeeded(result)) {
2604 // We need to perform our own replacement here because this API is still
2605 // used in patterns that "pad and hoist", for which the replacement values
2606 // need to be different.
2607 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2608 // that we have more composable abstractions.
2609 rewriter.replaceOp(target, *result);
2610 results.push_back(hoistedPadOp);
2612 }
2613 return emitDefaultSilenceableFailure(target);
2614}
2615
2616LogicalResult transform::HoistPadOp::verify() {
2617 ArrayRef<int64_t> transpose = getTranspose();
2618 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2619 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2620 transpose.end())) {
2621 return emitOpError() << "expects transpose to be a permutation, found "
2622 << getTranspose();
2623 }
2624 return success();
2625}
2626
2627//===----------------------------------------------------------------------===//
2628// PromoteOp
2629//===----------------------------------------------------------------------===//
2630
2631DiagnosedSilenceableFailure
2632transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2633 LinalgOp target,
2634 transform::ApplyToEachResultList &results,
2635 transform::TransformState &state) {
2636 LinalgPromotionOptions promotionOptions;
2637 if (!getOperandsToPromote().empty())
2638 promotionOptions = promotionOptions.setOperandsToPromote(
2639 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2640 if (getUseFullTilesByDefault())
2641 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2642 getUseFullTilesByDefault());
2643 if (getUseOriginalSubviewSize())
2644 promotionOptions =
2645 promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
2646 if (getUseAlloca())
2647 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2648 if (!getUseFullTileBuffers().empty())
2649 promotionOptions = promotionOptions.setUseFullTileBuffers(
2650 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2651 if (getAlignment().has_value())
2652 promotionOptions = promotionOptions.setAlignment(*getAlignment());
2653 if (getMemorySpace().has_value())
2654 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2655
2656 if (getMapping().has_value()) {
2657 // The mapping should only contain an element
2658 auto mapping = *getMapping();
2659 if (mapping.size() > 1)
2660 return emitDefaultDefiniteFailure(target);
2661
2662 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2663
2664 if (addressSpace.getAddressSpace() ==
2665 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2666 promotionOptions =
2667 promotionOptions
2671 .setUseFullTileBuffers({false, false});
2672 } else if (addressSpace.getAddressSpace() ==
2673 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2674 promotionOptions =
2675 promotionOptions
2679 .setUseFullTileBuffers({false, false});
2680 } else {
2681 return emitDefaultDefiniteFailure(target);
2682 }
2683 }
2684
2685 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2686 return emitDefaultDefiniteFailure(target);
2687
2688 rewriter.setInsertionPoint(target);
2689 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2690 if (failed(res))
2691 return emitDefaultDefiniteFailure(target);
2692 results.push_back(target);
2694}
2695
2696//===----------------------------------------------------------------------===//
2697// ReplaceOp
2698//===----------------------------------------------------------------------===//
2699
2700DiagnosedSilenceableFailure
2701transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2702 TransformResults &transformResults,
2703 TransformState &state) {
2704 auto payload = state.getPayloadOps(getTarget());
2705
2706 // Check for invalid targets.
2707 for (Operation *target : payload) {
2708 if (target->getNumOperands() > 0)
2709 return emitDefiniteFailure() << "expected target without operands";
2710 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2711 target->getNumRegions() > 0)
2712 return emitDefiniteFailure()
2713 << "expected target that is isolated from above";
2714 }
2715
2716 // Clone and replace.
2717 Operation *pattern = &getBodyRegion().front().front();
2718 SmallVector<Operation *> replacements;
2719 for (Operation *target : payload) {
2720 if (getOperation()->isAncestor(target))
2721 continue;
2722 rewriter.setInsertionPoint(target);
2723 Operation *replacement = rewriter.clone(*pattern);
2724 rewriter.replaceOp(target, replacement->getResults());
2725 replacements.push_back(replacement);
2726 }
2727 transformResults.set(cast<OpResult>(getReplacement()), replacements);
2729}
2730
2731void transform::ReplaceOp::getEffects(
2732 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2733 consumesHandle(getTargetMutable(), effects);
2734 producesHandle(getOperation()->getOpResults(), effects);
2735 modifiesPayload(effects);
2736}
2737
2738LogicalResult transform::ReplaceOp::verify() {
2739 if (!getBodyRegion().hasOneBlock())
2740 return emitOpError() << "expected one block";
2741 if (std::distance(getBodyRegion().front().begin(),
2742 getBodyRegion().front().end()) != 1)
2743 return emitOpError() << "expected one operation in block";
2744 Operation *replacement = &getBodyRegion().front().front();
2745 if (replacement->getNumOperands() > 0)
2746 return replacement->emitOpError()
2747 << "expected replacement without operands";
2748 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2749 replacement->getNumRegions() > 0)
2750 return replacement->emitOpError()
2751 << "expect op that is isolated from above";
2752 return success();
2753}
2754
2755//===----------------------------------------------------------------------===//
2756// ScalarizeOp
2757//===----------------------------------------------------------------------===//
2758
2759DiagnosedSilenceableFailure
2760transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2761 LinalgOp target,
2762 transform::ApplyToEachResultList &results,
2763 transform::TransformState &state) {
2764 scf::SCFTilingOptions tilingOptions;
2765 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2766 SmallVector<OpFoldResult> tileSizes;
2767 Location loc = target.getLoc();
2768 SmallVector<OpFoldResult> allShapeSizes =
2769 target.createFlatListOfOperandDims(b, loc);
2770 AffineMap map = target.getShapesToLoopsMap();
2771 if (!map)
2772 return tileSizes;
2773 SmallVector<OpFoldResult> shapeSizes =
2775 allShapeSizes);
2776 // If the shape size is dynamic, tile by 1.
2777 // Otherwise, do not tile (i.e. tile size 0).
2778 for (OpFoldResult shapeSize : shapeSizes) {
2779 tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2780 : b.getIndexAttr(1));
2781 }
2782 return tileSizes;
2783 });
2784 rewriter.setInsertionPoint(target);
2785 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2786 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2787 if (failed(maybeTilingResult))
2788 return emitDefaultDefiniteFailure(target);
2789
2790 if (target->getNumResults())
2791 rewriter.replaceOp(target, maybeTilingResult->replacements);
2792 else
2793 rewriter.eraseOp(target);
2794
2795 results.reserve(maybeTilingResult->tiledOps.size());
2796 for (Operation *tiled : maybeTilingResult->tiledOps)
2797 results.push_back(tiled);
2799}
2800
2801//===----------------------------------------------------------------------===//
2802// ConvertToLoopsOp
2803//===----------------------------------------------------------------------===//
2804
2805DiagnosedSilenceableFailure
2806transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2807 transform::TransformResults &results,
2808 transform::TransformState &state) {
2809 SmallVector<Operation *> loops;
2810 for (Operation *target : state.getPayloadOps(getTarget())) {
2811 auto tilingOp = dyn_cast<TilingInterface>(*target);
2812 if (!tilingOp) {
2813 DiagnosedSilenceableFailure diag =
2814 emitSilenceableError()
2815 << "expected the payload to implement TilingInterface";
2816 diag.attachNote(target->getLoc()) << "payload op";
2817 return diag;
2818 }
2819 rewriter.setInsertionPoint(target);
2820 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2821 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2822 if (failed(generatedLoops))
2823 return emitDefaultDefiniteFailure(target);
2824 for (scf::ForOp &loop : *generatedLoops) {
2825 loops.push_back(loop.getOperation());
2826 }
2827 rewriter.eraseOp(target);
2828 }
2829 results.set(cast<OpResult>(getResult()), loops);
2831}
2832
2833//===----------------------------------------------------------------------===//
2834// RewriteInDestinationPassingStyleOp
2835//===----------------------------------------------------------------------===//
2836
2837DiagnosedSilenceableFailure
2838transform::RewriteInDestinationPassingStyleOp::applyToOne(
2839 transform::TransformRewriter &rewriter, Operation *target,
2840 transform::ApplyToEachResultList &results,
2841 transform::TransformState &state) {
2842 rewriter.setInsertionPoint(target);
2843 FailureOr<Operation *> maybeResult =
2845 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2846 [&rewriter](auto op) {
2847 return rewriteInDestinationPassingStyle(rewriter, op);
2848 });
2849 if (failed(maybeResult))
2850 return emitDefaultSilenceableFailure(target);
2851 results.push_back(*maybeResult);
2853}
2854
2855//===----------------------------------------------------------------------===//
2856// SplitOp
2857//===----------------------------------------------------------------------===//
2858
2859DiagnosedSilenceableFailure
2860SplitOp::apply(transform::TransformRewriter &rewriter,
2861 TransformResults &results, TransformState &state) {
2862 // Collect the dynamic split points if provided.
2863 SmallVector<Operation *> payload =
2864 llvm::to_vector(state.getPayloadOps(getTarget()));
2865
2866 bool isMultiwaySplit = getMultiway();
2867
2868 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2869 return mlir::emitSilenceableFailure(getLoc())
2870 << "requires exactly one target when "
2871 "multiway split is enabled (got "
2872 << llvm::range_size(payload) << ")";
2873 }
2874
2875 SmallVector<OpFoldResult> chunkSizes;
2876
2877 if (!isMultiwaySplit)
2878 chunkSizes.reserve(payload.size());
2879
2880 if (getDynamicChunkSizes()) {
2882 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2883 chunkSizes = llvm::map_to_vector(
2884 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2885 if (op->getNumResults() != 1 ||
2886 !op->getResult(0).getType().isIndex()) {
2887 diag = emitSilenceableError()
2888 << "expected dynamic split point handle to point to a "
2889 "single-result index-typed op";
2890 diag.attachNote(op->getLoc()) << "dynamic split point";
2891 }
2892 return OpFoldResult(op->getResult(0));
2893 });
2894 } else {
2895 chunkSizes = llvm::map_to_vector(
2896 state.getParams(getDynamicChunkSizes()),
2897 [](Attribute attr) { return OpFoldResult(attr); });
2898 }
2899 if (diag.isSilenceableFailure())
2900 return diag;
2901
2902 // For multiway split, a single payload is expected to have multiple
2903 // split points.
2904 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2905 return emitDefiniteFailure()
2906 << "expected the dynamic split point handle to point to as "
2907 "many operations ("
2908 << chunkSizes.size() << ") as the target handle ("
2909 << payload.size() << ")";
2910 }
2911 } else {
2912 chunkSizes.resize(payload.size(),
2913 rewriter.getIndexAttr(getStaticChunkSizes()));
2914 }
2915
2916 auto checkStructuredOpAndDimensions =
2917 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2918 if (!linalgOp) {
2919 auto diag = emitSilenceableError() << "only applies to structured ops";
2920 diag.attachNote(loc) << "target op";
2921 return diag;
2922 }
2923
2924 if (getDimension() >= linalgOp.getNumLoops()) {
2925 auto diag = emitSilenceableError() << "dimension " << getDimension()
2926 << " does not exist in target op";
2927 diag.attachNote(loc) << "target op";
2928 return diag;
2929 }
2931 };
2932
2933 auto checkFailureInSplitting =
2934 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2935 if (hasFailed) {
2936 auto diag = emitDefiniteFailure() << "internal failure in splitting";
2937 diag.attachNote(loc) << "target op";
2938 return diag;
2939 }
2941 };
2942
2943 SmallVector<Operation *> opList;
2944 if (isMultiwaySplit) {
2945
2946 // Split a single target operation at multiple points.
2947 TilingInterface head, tail;
2948 Operation *target = payload.front();
2949
2950 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2951
2952 // Check that the target is a valid LinalgOp with correct dimensions.
2953 DiagnosedSilenceableFailure diag =
2954 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2955 if (diag.isSilenceableFailure())
2956 return diag;
2957
2958 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2959
2960 if (idx > 0)
2961 target = tail.getOperation();
2962
2963 if (!target)
2964 break;
2965
2966 linalgOp = cast<LinalgOp>(target);
2967 Location loc = target->getLoc();
2968
2969 rewriter.setInsertionPoint(linalgOp);
2970 std::tie(head, tail) = linalg::splitOp(
2971 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2972 getDimension(), chunkSize);
2973
2974 // Propagate errors.
2975 DiagnosedSilenceableFailure diag =
2976 checkFailureInSplitting(!head && !tail, loc);
2977 if (diag.isDefiniteFailure())
2978 return diag;
2979
2980 opList.push_back(head.getOperation());
2981 }
2982
2983 // Append any leftover parts to the end of the result list.
2984 if (tail)
2985 opList.push_back(tail.getOperation());
2986
2987 } else {
2988 // Split each target operation.
2989 SmallVector<Operation *> first, second;
2990 Operation *noSecondPart = nullptr;
2991 for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2992 Operation *target = std::get<0>(pair);
2993 Location loc = target->getLoc();
2994 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2995 DiagnosedSilenceableFailure diag =
2996 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2997
2998 if (diag.isSilenceableFailure())
2999 return diag;
3000
3001 rewriter.setInsertionPoint(linalgOp);
3002 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
3003 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
3004 getDimension(), std::get<1>(pair));
3005
3006 // Propagate errors.
3007 DiagnosedSilenceableFailure diagSplit =
3008 checkFailureInSplitting(!first.back() && !second.back(), loc);
3009 if (diagSplit.isDefiniteFailure())
3010 return diag;
3011
3012 // Do not add null second parts.
3013 if (!second.back()) {
3014 noSecondPart = target;
3015 second.pop_back();
3016 }
3017 }
3018
3019 if (second.size() != first.size() && !second.empty()) {
3020 auto diag = emitSilenceableError()
3021 << "splitting does not produce the second part for a subset "
3022 "of targets";
3023 diag.attachNote()
3024 << "expected splitting to produce the second part of all "
3025 "or none of the targets";
3026 diag.attachNote(noSecondPart->getLoc())
3027 << "first target with no second part";
3028 return diag;
3029 }
3030
3031 opList.append(first);
3032 if (!second.empty())
3033 opList.append(second);
3034 }
3035 results.set(cast<OpResult>(getSplitList()), opList);
3037}
3038
3039void SplitOp::getEffects(
3040 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3041 consumesHandle(getTargetMutable(), effects);
3042 if (getDynamicChunkSizes())
3043 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
3044 producesHandle(getOperation()->getOpResults(), effects);
3045 modifiesPayload(effects);
3046}
3047
3048ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
3049 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
3050 IntegerAttr staticChunkSizes;
3051 if (parser.parseOperand(target) || parser.parseKeyword("after"))
3052 return failure();
3053
3054 OptionalParseResult dynamicPointParseResult =
3055 parser.parseOptionalOperand(dynamicChunkSizes);
3056 if (!dynamicPointParseResult.has_value()) {
3057 int64_t staticChunkSizesValue;
3058 if (failed(parser.parseInteger(staticChunkSizesValue)))
3059 return failure();
3060
3061 staticChunkSizes =
3062 parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
3063 }
3064
3065 Type targetType;
3066 if (parser.parseOptionalAttrDict(result.attributes) ||
3067 parser.parseColonType(targetType) ||
3068 parser.resolveOperand(target, targetType, result.operands)) {
3069 return failure();
3070 }
3071 if (dynamicPointParseResult.has_value()) {
3072 Type chunkSizesType;
3073 if (failed(*dynamicPointParseResult) || parser.parseComma() ||
3074 parser.parseType(chunkSizesType) ||
3075 parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
3076 result.operands)) {
3077 return failure();
3078 }
3079
3080 staticChunkSizes =
3081 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
3082 }
3083
3084 result.addAttribute(
3085 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
3086 staticChunkSizes);
3087 result.addTypes(targetType);
3088 return success();
3089}
3090
3091void SplitOp::print(OpAsmPrinter &printer) {
3092 printer << " " << getTarget() << " after ";
3093 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
3094 if (staticChunkSize != ShapedType::kDynamic)
3095 printer << staticChunkSize;
3096 else
3097 printer << getDynamicChunkSizes();
3098 printer << " ";
3099 printer.printOptionalAttrDict(getOperation()->getAttrs(),
3100 {getStaticChunkSizesAttrName()});
3101 printer << " : " << getTarget().getType();
3102 if (staticChunkSize == ShapedType::kDynamic)
3103 printer << ", " << getDynamicChunkSizes().getType();
3104}
3105
3106LogicalResult SplitOp::verify() {
3107 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3108 (getDynamicChunkSizes() == nullptr)) {
3109 return emitOpError() << "expects either a dynamic or a static split "
3110 "point to be provided";
3111 }
3112 return success();
3113}
3114
3115//===----------------------------------------------------------------------===//
3116// SplitReductionOp
3117//===----------------------------------------------------------------------===//
3118
3119void transform::SplitReductionOp::build(
3120 OpBuilder &builder, OperationState &result, Value target,
3121 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
3122 bool useScalingAlgorithm, bool useAlloc) {
3123 MLIRContext *ctx = builder.getContext();
3124 result.addOperands(target);
3125 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
3126 builder.getI64IntegerAttr(splitFactor));
3127 result.addAttribute(
3128 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
3129 builder.getI64IntegerAttr(insertSplitDimension));
3130 if (innerParallel) {
3131 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
3132 builder.getUnitAttr());
3133 }
3134 if (useScalingAlgorithm) {
3135 result.addAttribute(
3136 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
3137 builder.getUnitAttr());
3138 }
3139 if (useAlloc) {
3140 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
3141 builder.getUnitAttr());
3142 }
3143 auto resultType = transform::AnyOpType::get(ctx);
3144 result.addTypes({resultType, resultType, resultType, resultType});
3145}
3146
3147DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3148 transform::TransformRewriter &rewriter, LinalgOp target,
3149 transform::ApplyToEachResultList &results,
3150 transform::TransformState &state) {
3151 ControlSplitReductionFn splitFn = [&](LinalgOp) {
3152 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3153 unsigned(getInsertSplitDimension()),
3154 bool(getInnerParallel())};
3155 };
3156 rewriter.setInsertionPoint(target);
3157 FailureOr<SplitReductionResult> splitResult =
3158 (getUseScalingAlgorithm())
3159 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
3160 : splitReduction(rewriter, target, splitFn, getUseAlloc());
3161 if (failed(splitResult))
3162 return emitDefaultDefiniteFailure(target);
3163
3164 results.push_back(splitResult->initOrAlloc);
3165 results.push_back(splitResult->fillOp);
3166 results.push_back(splitResult->splitLinalgOp);
3167 results.push_back(splitResult->resultCombiningLinalgOp);
3169}
3170
3171//===----------------------------------------------------------------------===//
3172// TileReductionUsingForOp
3173//===----------------------------------------------------------------------===//
3174
3175void transform::TileReductionUsingForOp::build(
3176 OpBuilder &builder, OperationState &result, Value target,
3177 ArrayRef<int64_t> staticTileSizes) {
3178 // Call the default builder.
3179 // This is future-proof re mixed static-dynamic and setting up the proper
3180 // operands segment sizes attributes for multiple variadic operands.
3181 // In the absence of this, horrible bugs ensue.
3182 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3183 MLIRContext *ctx = builder.getContext();
3184 auto opTy = transform::AnyOpType::get(ctx);
3185 auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
3186 build(builder, result,
3187 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3188 /*target=*/target,
3189 /*reduction_dims=*/nullptr,
3190 /*tile_sizes=*/staticTileSizesAttr);
3191}
3192
3193DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3194 transform::TransformRewriter &rewriter, Operation *target,
3195 transform::ApplyToEachResultList &results,
3196 transform::TransformState &state) {
3197 rewriter.setInsertionPoint(target);
3198
3199 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3200 if (!partialReductionOp) {
3202 target->getLoc(),
3203 "Operation should implement PartialReductionOpInterface");
3204 }
3205
3206 SmallVector<unsigned> reductionDims =
3207 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3208 if (reductionDims.empty()) {
3209 for (auto [idx, iteratorType] :
3210 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3211 if (iteratorType == utils::IteratorType::reduction)
3212 reductionDims.push_back(idx);
3213 }
3214 }
3215
3216 scf::SCFTilingOptions options;
3217 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3218 options.setReductionTilingStrategy(
3220 options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
3221 options.setReductionDims(reductionDims);
3222 FailureOr<scf::SCFTilingResult> result =
3223 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3224
3225 if (failed(result)) {
3226 return emitSilenceableFailure(getLoc(),
3227 "failed to tile using partial reduction");
3228 }
3229 rewriter.replaceOp(target, result->replacements);
3230 for (Value initValue : result->initialValues)
3231 results.push_back(initValue.getDefiningOp());
3232 for (auto *parallelTiledOp : result->tiledOps)
3233 results.push_back(parallelTiledOp);
3234 for (auto *mergeOp : result->mergeOps)
3235 results.push_back(mergeOp);
3236 results.push_back(result->loops.front());
3238}
3239
3240//===----------------------------------------------------------------------===//
3241// TileReductionUsingForallOp
3242//===----------------------------------------------------------------------===//
3243
3244void transform::TileReductionUsingForallOp::build(
3245 OpBuilder &builder, OperationState &result, Value target,
3246 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3247 ArrayAttr mapping) {
3248 // Call the default builder.
3249 // This is future-proof re mixed static-dynamic and setting up the proper
3250 // operands segment sizes attributes for multiple variadic operands.
3251 // In the absence of this, horrible bugs ensue.
3252 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3253 MLIRContext *ctx = builder.getContext();
3254 auto opTy = transform::AnyOpType::get(ctx);
3255 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3256 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3257 build(builder, result,
3258 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3259 /*target=*/target,
3260 /*reduction_dims=*/{},
3261 /*num_threads=*/staticNumThreadsAttr,
3262 /*tile_sizes=*/staticTileSizesAttr,
3263 /*mapping=*/mapping);
3264}
3265
3266DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3267 transform::TransformRewriter &rewriter, Operation *target,
3268 transform::ApplyToEachResultList &results,
3269 transform::TransformState &state) {
3270 rewriter.setInsertionPoint(target);
3271
3272 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3273 if (!partialReductionOp) {
3275 target->getLoc(),
3276 "Operation should implement PartialReductionOpInterface");
3277 }
3278 SmallVector<OpFoldResult> numThreads =
3279 getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
3280 SmallVector<OpFoldResult> tileSizes =
3282
3283 scf::SCFTilingOptions options;
3284 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3285 options.setReductionTilingStrategy(
3287 if (!getNumThreads().empty()) {
3288 options.setNumThreads(numThreads);
3289 } else {
3290 options.setTileSizes(tileSizes);
3291 }
3292 if (auto mapping = getMapping()) {
3293 options.setMapping(mapping.value().getValue());
3294 }
3295 SmallVector<unsigned> reductionDims =
3296 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3297 if (reductionDims.empty()) {
3298 for (auto [idx, iteratorType] :
3299 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3300 if (iteratorType == utils::IteratorType::reduction)
3301 reductionDims.push_back(idx);
3302 }
3303 }
3304 options.setReductionDims(reductionDims);
3305 FailureOr<scf::SCFTilingResult> result =
3306 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3307
3308 if (failed(result)) {
3309 auto diag = emitSilenceableError() << "could not tile reduction";
3310 return diag;
3311 }
3312 rewriter.replaceOp(target, result->replacements);
3313
3314 for (Value initValue : result->initialValues)
3315 results.push_back(initValue.getDefiningOp());
3316 for (auto *parallelTiledOp : result->tiledOps)
3317 results.push_back(parallelTiledOp);
3318 for (auto *mergeOp : result->mergeOps)
3319 results.push_back(mergeOp);
3320 results.push_back(result->loops.front());
3322}
3323
3324//===----------------------------------------------------------------------===//
3325// ContinuousTileSizesOp
3326//===----------------------------------------------------------------------===//
3327
3328DiagnosedSilenceableFailure
3329transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3330 TransformResults &transformResults,
3331 TransformState &state) {
3332
3333 SmallVector<Operation *> targetOps =
3334 llvm::to_vector(state.getPayloadOps(getTarget()));
3335
3336 if (!llvm::hasSingleElement(targetOps)) {
3337 return mlir::emitSilenceableFailure(getLoc())
3338 << "requires exactly one target (got " << llvm::range_size(targetOps)
3339 << ")";
3340 }
3341
3342 Operation *target = *targetOps.begin();
3343 auto linalgOp = dyn_cast<LinalgOp>(target);
3344 auto tileableOp = dyn_cast<TilingInterface>(target);
3345
3346 if (!linalgOp)
3347 return emitDefiniteFailure() << "expected Linalg Op";
3348
3349 OpBuilder builder(linalgOp.getContext());
3350
3351 if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
3352 if (linalgOp.hasDynamicShape()) {
3353 auto diag = emitSilenceableError()
3354 << "cannot compute parametric tile sizes for dynamically "
3355 "shaped payload op";
3356 diag.attachNote(linalgOp->getLoc()) << "payload op";
3357 return diag;
3358 }
3359
3360 FailureOr<StaticContinuousTileSizeSpecification> spec =
3361 computeStaticContinuousTileSizes(linalgOp, getDimension(),
3362 getTargetSize());
3363 if (failed(spec)) {
3364 return emitSilenceableError()
3365 << "failed to compute multi-size tiling sizes";
3366 }
3367
3368 SmallVector<int64_t> chunkSizes;
3369
3370 for (auto &&[tileSize, tripCount] :
3371 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3372 chunkSizes.push_back(tileSize * tripCount);
3373
3374 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3375 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3376 return builder.getI64IntegerAttr(value);
3377 });
3378 };
3379 transformResults.setParams(cast<OpResult>(getTileSizes()),
3380 getI64AttrsFromI64(spec->tileSizes));
3381 transformResults.setParams(cast<OpResult>(getChunkSizes()),
3382 getI64AttrsFromI64(chunkSizes));
3383
3385 }
3386
3387 builder.setInsertionPoint(linalgOp);
3388
3389 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
3390 unsigned dimension = getDimension();
3391
3392 FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
3393 builder, tileableOp, dimension, targetSize, true);
3394 if (failed(spec)) {
3395 return emitSilenceableError() << "could not generate tile size computation";
3396 }
3397
3398 AffineExpr s0 = builder.getAffineSymbolExpr(0);
3399 AffineExpr s1 = builder.getAffineSymbolExpr(1);
3400 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3401 return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
3402 ofrs);
3403 };
3404
3405 SmallVector<Value> chunkSizes;
3406 Value splitPoint;
3407 for (auto &&[tileSize, tripCount] :
3408 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3409 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3410 chunkSizes.push_back(splitPoint);
3411 }
3412
3413 auto getDefiningOps = [&](ArrayRef<Value> values) {
3414 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3415 return value.getDefiningOp();
3416 });
3417 };
3418
3419 transformResults.set(cast<OpResult>(getTileSizes()),
3420 getDefiningOps(spec->tileSizes));
3421 transformResults.set(cast<OpResult>(getChunkSizes()),
3422 getDefiningOps(chunkSizes));
3423
3425}
3426
3427LogicalResult transform::ContinuousTileSizesOp::verify() {
3428
3429 if (getTileSizes().getType() != getChunkSizes().getType()) {
3430 return emitOpError() << "expects all results type to be the same";
3431 }
3432
3433 return success();
3434}
3435
3436void transform::ContinuousTileSizesOp::getEffects(
3437 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3438 if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3439 onlyReadsPayload(effects);
3440 else
3441 modifiesPayload(effects);
3442 onlyReadsHandle(getTargetMutable(), effects);
3443 producesHandle(getOperation()->getOpResults(), effects);
3444}
3445
3447 Type targetType, Type tileSizes,
3448 Type) {
3449 printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
3450}
3451
3453 Type &targetType,
3454 Type &tileSizesType,
3455 Type &chunkSizesType) {
3456 FunctionType funcType;
3457 llvm::SMLoc typeLoc = parser.getCurrentLocation();
3458 if (failed(parser.parseType<FunctionType>(funcType)))
3459 return failure();
3460
3461 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3462 parser.emitError(typeLoc) << "expects a trailing functional type with one "
3463 "argument and one result";
3464 }
3465 targetType = funcType.getInput(0);
3466 tileSizesType = chunkSizesType = funcType.getResult(0);
3467
3468 return success();
3469}
3470
3471//===----------------------------------------------------------------------===//
3472// TileUsingForOp
3473//===----------------------------------------------------------------------===//
3474
3475void transform::TileUsingForOp::build(
3476 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3477 Value target, ArrayRef<int64_t> staticTileSizes,
3478 ArrayRef<int64_t> interchange,
3479 std::optional<ArrayRef<bool>> scalableSizes) {
3480 return build(builder, result, loopTypes,
3481 /*target=*/target,
3482 /*mixedTileSizes=*/
3483 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3484 interchange, scalableSizes);
3485}
3486
3487void transform::TileUsingForOp::build(
3488 OpBuilder &builder, OperationState &result, Value target,
3489 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3490 std::optional<ArrayRef<bool>> scalableSizes) {
3491 build(builder, result, target,
3492 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3493 interchange, scalableSizes);
3494}
3495
3496void transform::TileUsingForOp::build(
3497 OpBuilder &builder, OperationState &result, Value target,
3498 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3499 std::optional<ArrayRef<bool>> scalableSizes) {
3500 // Loop types are automaticaly splat by the callee, setting up one is
3501 // enough.
3502 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3503 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3504 scalableSizes);
3505}
3506
3507void transform::TileUsingForOp::build(
3508 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3509 Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3510 ArrayRef<int64_t> interchange,
3511 std::optional<ArrayRef<bool>> scalableSizes) {
3512 SmallVector<int64_t> staticTileSizes;
3513 SmallVector<Value> dynamicTileSizes;
3514 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3515 // Call the default builder which sets up the proper operands segment sizes
3516 // attributes for multiple variadic operands. In the absence of this,
3517 // horrible bugs ensue.
3518 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3519 unsigned numExpectedLoops =
3520 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3521 SmallVector<Type> resultTypes;
3522 resultTypes.reserve(numExpectedLoops);
3523 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3524 "expected one loop type or as many as loops");
3525 if (loopTypes.size() == 1)
3526 resultTypes.append(numExpectedLoops, loopTypes[0]);
3527 else
3528 llvm::append_range(resultTypes, loopTypes);
3529 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3530 if (scalableSizes.has_value())
3531 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3532 build(builder, result, /*tiled_linalg_op=*/target.getType(),
3533 /*loops=*/resultTypes,
3534 /*target=*/target,
3535 /*dynamic_sizes=*/dynamicTileSizes,
3536 /*static_sizes=*/staticTileSizesAttr,
3537 /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3538 /*scalable_sizes=*/expandedScalableSizes);
3539}
3540
3541LogicalResult transform::TileUsingForOp::verify() {
3542 if (getMixedSizes().size() != getScalableSizes().size())
3543 return emitOpError("expected same number of sizes (")
3544 << getMixedSizes().size() << ") and scalable sizes ("
3545 << getScalableSizes().size() << ")";
3546 ArrayRef<int64_t> staticSizes = getStaticSizes();
3547 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3548 if (getLoops().size() != numExpectedLoops)
3549 return emitOpError("expected number of loops to tile (")
3550 << numExpectedLoops << ") to match number of `loops` results ("
3551 << getLoops().size() << ")";
3552 return success();
3553}
3554
3555DiagnosedSilenceableFailure
3556transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3557 TransformResults &transformResults,
3558 TransformState &state) {
3559 ArrayRef<int64_t> tileSizes = getStaticSizes();
3560
3561 SmallVector<Operation *> targets =
3562 llvm::to_vector(state.getPayloadOps(getTarget()));
3563 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3564 SmallVector<SmallVector<int64_t>> paramSizes;
3565 dynamicSizeProducers.reserve(getDynamicSizes().size());
3566 paramSizes.reserve(getDynamicSizes().size());
3567 for (Value transformValue : getDynamicSizes()) {
3568 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
3569 dynamicSizeProducers.push_back({});
3570 ArrayRef<Attribute> params = state.getParams(transformValue);
3571 paramSizes.push_back(llvm::map_to_vector(params, [](Attribute attr) {
3572 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3573 }));
3574
3575 if (paramSizes.back().size() != targets.size()) {
3576 DiagnosedSilenceableFailure diag =
3577 emitSilenceableError()
3578 << "expected as many parameter values ("
3579 << dynamicSizeProducers.back().size() << ") as target ops ("
3580 << targets.size() << ")";
3581 diag.attachNote(transformValue.getLoc()) << "for this parameter";
3582 return diag;
3583 }
3584
3585 continue;
3586 }
3587 paramSizes.push_back({});
3588 dynamicSizeProducers.push_back(
3589 llvm::to_vector(state.getPayloadOps(transformValue)));
3590
3591 if (dynamicSizeProducers.back().size() != targets.size()) {
3592 DiagnosedSilenceableFailure diag =
3593 emitSilenceableError()
3594 << "expected as many dynamic size-producing operations ("
3595 << dynamicSizeProducers.back().size() << ") as target ops ("
3596 << targets.size() << ")";
3597 diag.attachNote(transformValue.getLoc()) << "for this handle";
3598 return diag;
3599 }
3600
3601 for (Operation *op : dynamicSizeProducers.back()) {
3602 if (op->getNumResults() == 1 &&
3603 isa<IndexType>(op->getResult(0).getType())) {
3604 continue;
3605 }
3606
3607 DiagnosedSilenceableFailure diag =
3608 emitSilenceableError() << "expected sizes to be produced by ops "
3609 "with a single index-type result";
3610 diag.attachNote(op->getLoc()) << "size producer op";
3611 diag.attachNote(transformValue.getLoc()) << "for this handle";
3612 return diag;
3613 }
3614 }
3615
3616 SmallVector<Operation *> tiled;
3617 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3618 loops.resize(getLoops().size());
3619 auto scalableSizes = getScalableSizes();
3620 for (auto [i, op] : llvm::enumerate(targets)) {
3621 auto tilingInterface = dyn_cast<TilingInterface>(op);
3622 if (!tilingInterface) {
3623 DiagnosedSilenceableFailure diag =
3624 emitSilenceableError()
3625 << "only ops implementing TilingInterface are supported";
3626 diag.attachNote(op->getLoc()) << "target op";
3627 return diag;
3628 }
3629 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3630 DiagnosedSilenceableFailure diag =
3631 emitSilenceableError()
3632 << "too many tiles provided, expected at most "
3633 << tilingInterface.getLoopIteratorTypes().size() << " found "
3634 << tileSizes.size();
3635 diag.attachNote(op->getLoc()) << "target op";
3636 return diag;
3637 }
3638
3639 scf::SCFTilingOptions tilingOptions;
3640 if (tileSizes.empty()) {
3641 tilingOptions.setTileSizeComputationFunction(
3642 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3643 return {};
3644 });
3645 } else {
3646 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3647 Operation *) {
3648 SmallVector<OpFoldResult> sizes;
3649 sizes.reserve(tileSizes.size());
3650 unsigned dynamicIdx = 0;
3651
3652 for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3653 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3654 if (scalableSizes[ofrIdx]) {
3656 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3657 Value vscale =
3658 vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
3659 sizes.push_back(
3660 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3661 } else {
3662 sizes.push_back(attr);
3663 }
3664 continue;
3665 }
3666 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3667 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3668 ++dynamicIdx;
3669 assert((dynamicSizes.empty() ^ params.empty()) &&
3670 "expected either dynamic sizes or parameters");
3671 if (!params.empty()) {
3672 sizes.push_back(b.getIndexAttr(params[index]));
3673 } else {
3674 sizes.push_back(dynamicSizes[index]->getResult(0));
3675 }
3676 }
3677 return sizes;
3678 });
3679 }
3680
3681 tilingOptions.setInterchange(getInterchange());
3682 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3683 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3684 if (failed(maybeTilingResult))
3686
3687 rewriter.replaceOp(op, maybeTilingResult->replacements);
3688
3689 tiled.append(maybeTilingResult->tiledOps);
3690 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3691 loops[en2.index()].push_back(en2.value());
3692 }
3693
3694 transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3695 for (const auto &en : llvm::enumerate(loops))
3696 transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3697
3699}
3700
3701SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3702 ValueRange dynamic = getDynamicSizes();
3703 ArrayRef<int64_t> tileSizes = getStaticSizes();
3704 SmallVector<OpFoldResult> results;
3705 results.reserve(tileSizes.size());
3706 unsigned dynamicPos = 0;
3707 Builder builder(getContext());
3708 for (int64_t size : tileSizes) {
3709 if (size == ShapedType::kDynamic) {
3710 results.push_back(dynamic[dynamicPos++]);
3711 } else {
3712 results.push_back(builder.getIndexAttr(size));
3713 }
3714 }
3715 return results;
3716}
3717
3718void transform::TileUsingForOp::getEffects(
3719 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3720 consumesHandle(getTargetMutable(), effects);
3721 onlyReadsHandle(getDynamicSizesMutable(), effects);
3722 producesHandle(getOperation()->getOpResults(), effects);
3723 modifiesPayload(effects);
3724}
3725
3726//===----------------------------------------------------------------------===//
3727// TileUsingForallOp
3728//===----------------------------------------------------------------------===//
3729
3730void transform::TileUsingForallOp::build(OpBuilder &builder,
3731 OperationState &result, Value target,
3732 ArrayRef<int64_t> staticTileSizes,
3733 transform::TileSizesSpec,
3734 ArrayAttr mapping) {
3735 return build(builder, result,
3736 /*target=*/target,
3737 /*mixedTileSizes=*/
3738 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3739 /*_=*/TileSizesSpec(),
3740 /*mapping=*/mapping);
3741}
3742
3743void transform::TileUsingForallOp::build(OpBuilder &builder,
3744 OperationState &result, Value target,
3745 ArrayRef<OpFoldResult> mixedTileSizes,
3746 transform::TileSizesSpec,
3747 ArrayAttr mapping) {
3748 SmallVector<int64_t> staticTileSizes;
3749 SmallVector<Value> dynamicTileSizes;
3750 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3751 // Call the default builder which sets up the proper operands segment sizes
3752 // attributes for multiple variadic operands. In the absence of this,
3753 // horrible bugs ensue.
3754 MLIRContext *ctx = builder.getContext();
3755 auto operationType = transform::AnyOpType::get(ctx);
3756 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3757 build(builder, result,
3758 /*resultTypes=*/TypeRange{operationType, operationType},
3759 /*target=*/target,
3760 /*num_threads=*/ValueRange{},
3761 /*tile_sizes=*/dynamicTileSizes,
3762 /*packed_num_threads=*/Value(),
3763 /*packed_tile_sizes=*/Value(),
3764 /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3765 /*static_tile_sizes=*/staticTileSizesAttr,
3766 /*mapping=*/mapping);
3767}
3768
3769void transform::TileUsingForallOp::build(OpBuilder &builder,
3770 OperationState &result, Value target,
3771 ArrayRef<int64_t> staticNumThreads,
3772 transform::NumThreadsSpec,
3773 ArrayAttr mapping) {
3774 return build(builder, result, target,
3775 getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3776 NumThreadsSpec(), mapping);
3777}
3778
3779void transform::TileUsingForallOp::build(OpBuilder &builder,
3780 OperationState &result, Value target,
3781 ArrayRef<OpFoldResult> mixedNumThreads,
3782 transform::NumThreadsSpec,
3783 ArrayAttr mapping) {
3784 SmallVector<int64_t> staticNumThreads;
3785 SmallVector<Value> dynamicNumThreads;
3786 dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3787 staticNumThreads);
3788 // Call the default builder which sets up the proper operands segment sizes
3789 // attributes for multiple variadic operands. In the absence of this,
3790 // horrible bugs ensue.
3791 MLIRContext *ctx = builder.getContext();
3792 auto operationType = transform::AnyOpType::get(ctx);
3793 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3794 build(builder, result,
3795 /*resultTypes=*/TypeRange{operationType, operationType},
3796 /*target=*/target,
3797 /*num_threads=*/dynamicNumThreads,
3798 /*tile_sizes=*/ValueRange{},
3799 /*packed_num_threads=*/Value(),
3800 /*packed_tile_sizes=*/Value(),
3801 /*static_num_threads=*/staticNumThreadsAttr,
3802 /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3803 /*mapping=*/mapping);
3804}
3805
3806/// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3807/// normalized upper bound.
3808static SmallVector<OpFoldResult>
3811 ArrayRef<OpFoldResult> steps) {
3812 AffineExpr s0, s1, s2;
3813 bindSymbols(rewriter.getContext(), s0, s1, s2);
3814 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3815 SmallVector<OpFoldResult> normalizedUbs;
3816 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3818 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3819 normalizedUbs.push_back(normalizedUb);
3820 }
3821 return normalizedUbs;
3822}
3823
3824/// When a loop is normalized, the uses of the induction variable within the
3825/// loop need to replaced with `original_lb + old_iv * original_step`.
3827 Location loc, ValueRange ivs,
3829 ArrayRef<OpFoldResult> steps) {
3830 AffineExpr s0, s1;
3831 AffineExpr d0;
3832 bindSymbols(rewriter.getContext(), s0, s1);
3833 bindDims(rewriter.getContext(), d0);
3834 AffineExpr denormExpr = s0 + d0 * s1;
3835 SmallVector<Value> denormalizedIvs;
3836
3837 for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3839 rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3840 denormalizedIvs.push_back(
3841 getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3842 }
3843 return denormalizedIvs;
3844}
3845
3846/// Given a `scf.forall` loop return a loop op with the loop bounds
3847/// normalized.
3848/// TODO: Replace this with a general utility to normalize `scf.forall`.
3849/// At the time of writing, this wasnt done since adding this to `scf`
3850/// dialect would disallow using of `affine.apply` operations due
3851/// to cyclic dependencies. To avoid churn in lit tests
3852/// with the change this was added with, defer that to a follow up.
3853static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3854 scf::ForallOp loop) {
3855 SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3856 SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3857 SmallVector<OpFoldResult> steps = loop.getMixedStep();
3858
3859 if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3860 return loop;
3861 }
3862
3863 Location loc = loop.getLoc();
3864 SmallVector<OpFoldResult> normalizedUbs =
3865 normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3866 SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3867 rewriter.getIndexAttr(0));
3868 SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3869 rewriter.getIndexAttr(1));
3870
3871 auto normalizedForallOp = scf::ForallOp::create(
3872 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3873 loop.getOutputs(), loop.getMapping(),
3874 [](OpBuilder &, Location, ValueRange) {});
3875
3876 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3877 OpBuilder::InsertionGuard g(rewriter);
3878 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3879 rewriter.setInsertionPointToStart(normalizedLoopBlock);
3880
3881 SmallVector<Value> argValues =
3882 denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3883 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3884 normalizedForallOp.getRegionIterArgs().end());
3885 Block *origLoopBlock = loop.getBody();
3886 rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3887
3888 rewriter.replaceOp(loop, normalizedForallOp);
3889 return normalizedForallOp;
3890}
3891
3893 RewriterBase &rewriter, transform::TransformState &state,
3894 TransformOpInterface transformOp, Operation *target,
3895 ArrayRef<OpFoldResult> mixedNumThreads,
3896 ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3897 scf::SCFTilingResult &tilingResult) {
3898 // Transform all targets one by one.
3899 auto tileableOp = dyn_cast<TilingInterface>(target);
3900 if (!tileableOp) {
3902 transformOp.emitSilenceableError()
3903 << "only TilingInterface ops are supported";
3904 diag.attachNote(target->getLoc()) << "target op";
3905 return diag;
3906 }
3907 rewriter.setInsertionPoint(tileableOp);
3908 scf::SCFTilingOptions options;
3909 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3910 if (!mixedNumThreads.empty()) {
3911 options.setNumThreads(mixedNumThreads);
3912 } else {
3913 options.setTileSizes(mixedTileSizes);
3914 }
3915 if (mapping) {
3916 options.setMapping(mapping.value().getValue());
3917 }
3918 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3919 scf::tileUsingSCF(rewriter, tileableOp, options);
3920
3921 if (failed(maybeTilingResult))
3922 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3923
3924 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3925
3926 tilingResult = *maybeTilingResult;
3927
3928 // Rank-0 ops produce no loops; skip normalization when there is nothing
3929 // to normalize.
3930 if (mixedNumThreads.empty() && !tilingResult.loops.empty()) {
3931 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3932 OpBuilder::InsertionGuard g(rewriter);
3933 rewriter.setInsertionPoint(generatedForallOp);
3934 scf::ForallOp normalizedForallOp =
3935 normalizeForallLoopOp(rewriter, generatedForallOp);
3936 tilingResult.loops.front() = normalizedForallOp;
3937 }
3938
3940}
3941
3942DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3944 transform::TransformResults &transformResults,
3946 auto transformOp = cast<TransformOpInterface>(getOperation());
3947
3948 // Result payload ops.
3950 SmallVector<Operation *> tiledOps;
3951
3952 // Unpack handles.
3953 SmallVector<OpFoldResult> mixedNumThreads;
3955 getPackedNumThreads()
3957 state, transformOp, mixedNumThreads, getPackedNumThreads())
3959 state, transformOp, mixedNumThreads, getMixedNumThreads());
3960 if (!status.succeeded())
3961 return status;
3962 SmallVector<OpFoldResult> mixedTileSizes;
3963 status = getPackedTileSizes()
3965 state, transformOp, mixedTileSizes, getPackedTileSizes())
3967 state, transformOp, mixedTileSizes, getMixedTileSizes());
3968 if (!status.succeeded())
3969 return status;
3970
3971 for (Operation *target : state.getPayloadOps(getTarget())) {
3972 scf::SCFTilingResult tilingResult;
3974 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3975 getMapping(), tilingResult);
3976 if (!diag.succeeded())
3977 return diag;
3978 if (!tilingResult.loops.empty())
3979 tileOps.push_back(tilingResult.loops.front());
3980 tiledOps.append(tilingResult.tiledOps);
3981 }
3982
3983 transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3984 transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3985
3987}
3988
3989void transform::TileUsingForallOp::getEffects(
3990 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3991 consumesHandle(getTargetMutable(), effects);
3992 onlyReadsHandle(getTileSizesMutable(), effects);
3993 onlyReadsHandle(getNumThreadsMutable(), effects);
3994 onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3995 onlyReadsHandle(getPackedTileSizesMutable(), effects);
3996 producesHandle(getOperation()->getOpResults(), effects);
3997 modifiesPayload(effects);
3998}
3999
4000SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
4001 Builder b(getContext());
4002 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
4003}
4004
4005SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
4006 Builder b(getContext());
4007 return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
4008}
4009
4010LogicalResult TileUsingForallOp::verify() {
4011 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
4012 static_cast<int>(getPackedNumThreads() != Value());
4013 if (numThreadsSpec > 1)
4014 return emitOpError(
4015 "num_threads and packed_num_threads are mutually exclusive");
4016 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
4017 static_cast<int>(getPackedTileSizes() != Value());
4018 if (tileSizesSpec > 1)
4019 return emitOpError(
4020 "tile_sizes and packed_tile_sizes are mutually exclusive");
4021 if (numThreadsSpec == 0 && tileSizesSpec == 0)
4022 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
4023 "must be specified");
4024 return success();
4025}
4026
4027//===----------------------------------------------------------------------===//
4028// VectorizeChildrenAndApplyPatternsOp
4029//===----------------------------------------------------------------------===//
4030
4031void transform::VectorizeChildrenAndApplyPatternsOp::build(
4032 OpBuilder &builder, OperationState &result, Value target,
4033 bool foldTypeExtensionsIntoContract, bool vectorizePadding,
4034 bool vectorizeExtract, bool flatten1DDepthwiseConv) {
4035 result.addOperands(target);
4036 if (foldTypeExtensionsIntoContract) {
4037 result.addAttribute(
4038 VectorizeChildrenAndApplyPatternsOp::
4039 getFoldTypeExtensionsIntoContractAttrName(result.name),
4040 builder.getUnitAttr());
4041 }
4042 if (vectorizePadding) {
4043 result.addAttribute(
4044 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4045 result.name),
4046 builder.getUnitAttr());
4047 }
4048 if (vectorizeExtract) {
4049 result.addAttribute(
4050 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4051 result.name),
4052 builder.getUnitAttr());
4053 }
4054 if (flatten1DDepthwiseConv) {
4055 result.addAttribute(
4056 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4057 result.name),
4058 builder.getUnitAttr());
4059 }
4060 result.addTypes(transform::AnyOpType::get(builder.getContext()));
4061}
4062
4063namespace {
4064/// This is an helper only to call vectorize via a pattern inside of
4065/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
4066struct VectorizationPattern : public RewritePattern {
4067 explicit VectorizationPattern(MLIRContext *context,
4068 bool vectorizeExtract = false,
4069 bool flattenConv = false)
4070 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
4071 vectorizeNDExtract(vectorizeExtract),
4072 flatten1DDepthwiseConv(flattenConv) {}
4073 LogicalResult matchAndRewrite(Operation *op,
4074 PatternRewriter &rewriter) const override {
4076 return rewriter.notifyMatchFailure(op,
4077 "Unsupported Op, cannot vectorize");
4078 FailureOr<VectorizationResult> vectorResults =
4079 vectorize(rewriter, op, /*inputVectorSizes=*/{},
4080 /*inputScalableVecDims=*/{}, vectorizeNDExtract,
4081 flatten1DDepthwiseConv);
4082 if (failed(vectorResults))
4083 return failure();
4084 rewriter.replaceOp(op, vectorResults->replacements);
4085 return success();
4086 }
4087
4088private:
4089 /// Controls whether to vectorize `tensor.extract` when the input tensor is
4090 /// rank >= 2.
4091 bool vectorizeNDExtract = false;
4092 /// Controls whether to "flatten" the channel dimension when vectorising 1D
4093 /// depthwise convolutions. This should lead to bette vectorization for
4094 /// tensors with a low number of channel dimensions.
4095 bool flatten1DDepthwiseConv = false;
4096};
4097} // namespace
4098
/// Greedily applies a set of vectorization-related rewrite patterns to
/// `target` and, on success, appends `target` itself to `results`.
///
/// `target` must be isolated from above; otherwise an op error is emitted.
/// Pattern application is tracked through a `TrackingListener`, and a
/// definite failure is emitted if the greedy driver fails.
///
/// NOTE(review): this listing skips source lines 4107, 4116, 4119, 4121,
/// 4128, 4133, 4136, 4139, 4141 and 4150 (the embedded numbering jumps over
/// them). The statements on those lines, presumably the bodies of the `if`s
/// below and the return statements, must be recovered from the upstream file
/// before this block can compile.
4099DiagnosedSilenceableFailure
4100transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4101 transform::TransformRewriter &rewriter, Operation *target,
4102 transform::ApplyToEachResultList &results,
4103 transform::TransformState &state) {
// Reject targets that are not isolated from above.
4104 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4105 auto diag = this->emitOpError("requires isolated-from-above targets");
4106 diag.attachNote(target->getLoc()) << "non-isolated target";
4108 }
4109
// Start from the vectorization pattern itself, configured from the op's
// options.
4110 MLIRContext *ctx = getContext();
4111 RewritePatternSet patterns(ctx);
4112 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4113 getFlatten_1dDepthwiseConv());
4114
// The statements guarded by the next two conditions are missing from this
// listing (see the note at the top).
4115 if (!getDisableTransferPermutationMapLoweringPatterns())
4117
4118 if (!getDisableMultiReductionToContractPatterns())
4120
4122
// Copy-forwarding patterns are added with a higher benefit (2) than the
// vectorization pattern (1).
4123 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
4124 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4125 /*benefit=*/2);
4126 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4127 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4129
4130 patterns.add<CopyVectorizationPattern>(ctx);
4131
4132 if (getFoldTypeExtensionsIntoContract())
4134
4135 if (getVectorizePadding()) {
4137 // This creates an alternative path for lowering tensor.pad - by
4138 // decomposing it into e.g. linalg.fill.
4140 }
4142
// Apply everything greedily; the listener reports rewrites back to the
// transform state.
4143 TrackingListener listener(state, *this);
4144 if (failed(
4145 applyPatternsGreedily(target, std::move(patterns),
4146 GreedyRewriteConfig().setListener(&listener))))
4147 return emitDefaultDefiniteFailure(target);
4148
// The result handle points at the (rewritten in place) target.
4149 results.push_back(target);
4151}
4152
4153//===----------------------------------------------------------------------===//
4154// VectorizeOp
4155//===----------------------------------------------------------------------===//
4156
4157DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4158 transform::TransformRewriter &rewriter,
4159 mlir::transform::TransformResults &transformResults,
4160 mlir::transform::TransformState &state) {
4161 auto targets = state.getPayloadOps(getTarget());
4162 if (std::empty(targets))
4164 auto transformOp = cast<TransformOpInterface>(getOperation());
4165 SmallVector<int64_t> vectorSizes;
4166 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
4167 state, transformOp, getMixedVectorSizes(), vectorSizes);
4168 if (!status.succeeded())
4169 return status;
4170
4171 // TODO: Check that the correct number of vectorSizes was provided.
4172 for (Operation *target : targets) {
4174 return mlir::emitSilenceableFailure(target->getLoc())
4175 << "Unsupported Op, cannot vectorize";
4176 }
4177 FailureOr<VectorizationResult> vectorResults =
4178 linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
4179 getVectorizeNdExtract().value_or(false),
4180 /*flatten1DDepthwiseConv=*/false,
4181 getAssumeDynamicDimsMatchVecSizes().value_or(false),
4182 getCreateNamedContraction().value_or(false));
4183 if (failed(vectorResults)) {
4184 return mlir::emitSilenceableFailure(target->getLoc())
4185 << "Attempted to vectorize, but failed";
4186 }
4187 rewriter.replaceOp(target, vectorResults->replacements);
4188 }
4189
4191}
4192
4193void transform::VectorizeOp::getEffects(
4194 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4195 consumesHandle(getTargetMutable(), effects);
4196 onlyReadsHandle(getVectorSizesMutable(), effects);
4197 modifiesPayload(effects);
4198}
4199
4200SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4201 OpBuilder b(getContext());
4202 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
4203}
4204
4205LogicalResult transform::VectorizeOp::verify() {
4206 if (getStaticVectorSizes().size() != getScalableSizes().size())
4207 return emitOpError("expected same number of vector sizes (")
4208 << getStaticVectorSizes().size() << ") and scalable sizes ("
4209 << getScalableSizes().size() << ")";
4210 return success();
4211}
4212
4213//===----------------------------------------------------------------------===//
4214// HoistRedundantVectorTransfersOp
4215//===----------------------------------------------------------------------===//
4216
4217DiagnosedSilenceableFailure
4218transform::HoistRedundantVectorTransfersOp::applyToOne(
4219 transform::TransformRewriter &rewriter, func::FuncOp target,
4220 transform::ApplyToEachResultList &results,
4221 transform::TransformState &state) {
4222 // WARNING: This hoisting does not model parallelism and is generally
4223 // incorrect when used on distributed loops with memref semantics!
4224 // TODO: obsolete and should be retired.
4225 linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
4226 results.push_back(target);
4228}
4229
4230//===----------------------------------------------------------------------===//
4231// HoistRedundantVectorBroadcastsOp
4232//===----------------------------------------------------------------------===//
4233
4234DiagnosedSilenceableFailure
4235transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4236 transform::TransformRewriter &rewriter, mlir::Operation *target,
4237 transform::ApplyToEachResultList &results,
4238 transform::TransformState &state) {
4239 rewriter.setInsertionPoint(target);
4241 results.push_back(target);
4243}
4244
4245//===----------------------------------------------------------------------===//
4246// ConvertConv2DToImg2ColOp.
4247//===----------------------------------------------------------------------===//
4248
4249DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4250 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4251 transform::ApplyToEachResultList &results,
4252 transform::TransformState &state) {
4253 rewriter.setInsertionPoint(target);
4254 auto maybeTransformed =
4256 target)
4257 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4258 return rewriteInIm2Col(rewriter, op);
4259 })
4260 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4261 return rewriteInIm2Col(rewriter, op);
4262 })
4263 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4264 return rewriteInIm2Col(rewriter, op);
4265 })
4266 .Case([&](linalg::Conv2DNchwFchwOp op) {
4267 return rewriteInIm2Col(rewriter, op);
4268 })
4269 .Default([&](Operation *op) {
4270 return rewriter.notifyMatchFailure(op, "not supported");
4271 });
4272 if (failed(maybeTransformed))
4273 return emitDefaultSilenceableFailure(target);
4274 // Handle to the operation producing the img2col tensor.
4275 results.push_back(maybeTransformed->first);
4276 // Handle to the operation that replaces the original convolution.
4277 results.push_back(maybeTransformed->second);
4279}
4280
4281//===----------------------------------------------------------------------===//
4282// FlattenElementwiseLinalgOp.
4283//===----------------------------------------------------------------------===//
4284
4285DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4286 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4287 transform::ApplyToEachResultList &results,
4288 transform::TransformState &state) {
4289 rewriter.setInsertionPoint(target);
4290 if (!isElementwise(target))
4291 return mlir::emitSilenceableFailure(target->getLoc())
4292 << "only elementwise flattening is supported";
4293
4294 // If rank <= 1, do nothing
4295 if (target.getNumLoops() <= 1) {
4296 results.push_back(target);
4298 }
4299
4300 // Attempt to flatten all dims to one.
4301 ReassociationIndices reassociation(target.getNumLoops());
4302 std::iota(reassociation.begin(), reassociation.end(), 0);
4303 auto maybeFlattened =
4304 collapseOpIterationDims(target, reassociation, rewriter);
4305 if (failed(maybeFlattened))
4306 return mlir::emitSilenceableFailure(target->getLoc())
4307 << "attempted to flatten, but failed";
4308 results.push_back(maybeFlattened->collapsedOp);
4309 rewriter.replaceOp(target, maybeFlattened->results);
4311}
4312
4313//===----------------------------------------------------------------------===//
4314// TransposeConv2DOp
4315//===----------------------------------------------------------------------===//
4316
4317DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4318 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4319 transform::ApplyToEachResultList &results,
4320 transform::TransformState &state) {
4321 rewriter.setInsertionPoint(target);
4322 auto maybeTransformed =
4324 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4325 return transposeConv2D(rewriter, op);
4326 })
4327 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4328 return transposeConv2D(rewriter, op);
4329 })
4330 .Default([&](Operation *op) {
4331 return rewriter.notifyMatchFailure(op, "not supported");
4332 });
4333 if (failed(maybeTransformed))
4334 return emitDefaultSilenceableFailure(target);
4335 // Handle to the new Conv2D operation with transposed filters
4336 results.push_back(*maybeTransformed);
4338}
4339
4340//===----------------------------------------------------------------------===//
4341// TransposeMatmulOp
4342//===----------------------------------------------------------------------===//
4343
4344DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4345 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4346 transform::ApplyToEachResultList &results,
4347 transform::TransformState &state) {
4348 rewriter.setInsertionPoint(target);
4349 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4350 auto maybeTransformed =
4352 .Case([&](linalg::MatmulOp op) {
4353 return transposeMatmul(rewriter, op, transposeLHS);
4354 })
4355 .Case([&](linalg::BatchMatmulOp op) {
4356 return transposeBatchMatmul(rewriter, op, transposeLHS);
4357 })
4358 .Default(failure());
4359 if (failed(maybeTransformed))
4360 return emitSilenceableFailure(target->getLoc()) << "not supported";
4361 // Handle to the new Matmul operation with transposed filters
4362 results.push_back(*maybeTransformed);
4364}
4365
4366//===----------------------------------------------------------------------===//
4367// InsertSliceToCopyOp
4368//===----------------------------------------------------------------------===//
4369template <typename OpTy>
4370static DiagnosedSilenceableFailure
4371doit(RewriterBase &rewriter, OpTy target,
4374 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4375 tensor::ParallelInsertSliceOp>() &&
4376 "wrong op type");
4377
4378 if (auto copySource =
4379 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4380 results.push_back(copySource);
4382 }
4383
4384 // If we are inside a `ParallelCombiningOp` region, temporarily set the
4385 // insertion point outside: only ops implementing ParallelCombiningOpInterface
4386 // are allowed in there.
4387 if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4388 rewriter.setInsertionPoint(target->getParentOp());
4389
4390 Value extracted = tensor::ExtractSliceOp::create(
4391 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4392 target.getMixedSizes(), target.getMixedStrides());
4393 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4394 target.getSource(), extracted)
4395 .getResult(0);
4396 // Reset the insertion point.
4397 rewriter.setInsertionPoint(target);
4398 rewriter.replaceOpWithNewOp<OpTy>(
4399 target, copied, target.getDest(), target.getMixedOffsets(),
4400 target.getMixedSizes(), target.getMixedStrides());
4401
4402 results.push_back(copied.getDefiningOp());
4404}
4405
4406DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4407 transform::TransformRewriter &rewriter, Operation *targetOp,
4408 transform::ApplyToEachResultList &results,
4409 transform::TransformState &state) {
4410
4411 rewriter.setInsertionPoint(targetOp);
4412 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4413 return doit(rewriter, target, results, state);
4414 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4415 return doit(rewriter, target, results, state);
4416
4417 DiagnosedSilenceableFailure diag =
4418 emitSilenceableError()
4419 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4420 diag.attachNote(targetOp->getLoc()) << "target op";
4421 return diag;
4422}
4423
4424//===----------------------------------------------------------------------===//
4425// MapCopyToThreadsOp
4426//===----------------------------------------------------------------------===//
4427
4428DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4429 transform::TransformRewriter &rewriter, Operation *target,
4430 transform::ApplyToEachResultList &results,
4431 transform::TransformState &state) {
4432 // Check if the op is supported.
4433 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4434 DiagnosedSilenceableFailure diag =
4435 emitSilenceableError()
4436 << "only linalg.copy and tensor.pad target ops are supported";
4437 diag.attachNote(target->getLoc()) << "target op";
4438 return diag;
4439 }
4440 assert(target->getNumResults() == 1 && "expected single result");
4441 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4442 if (!resultShapedType.hasStaticShape()) {
4443 DiagnosedSilenceableFailure diag =
4444 emitSilenceableError()
4445 << "only statically sized ops of rank <= 3 are supported";
4446 diag.attachNote(target->getLoc()) << "target op";
4447 return diag;
4448 }
4449
4450 // Conservatively set the minimum viable desired bitwidth alignment.
4451 int64_t desiredBitAlignment = getDesiredBitAlignment();
4452 int64_t eltBitwidth =
4453 resultShapedType.getElementType().getIntOrFloatBitWidth();
4454 if (desiredBitAlignment % eltBitwidth != 0) {
4455 desiredBitAlignment = eltBitwidth;
4456 }
4457
4458 gpu::CopyMappingInfo mapping(
4459 /*ctx=*/getContext(),
4460 /*totalNumThreads=*/getTotalNumThreads(),
4461 /*alignment=*/desiredBitAlignment,
4462 /*sizes=*/resultShapedType.getShape(),
4463 /*favorPredication=*/false,
4464 /*elementalBitwidth=*/
4465 resultShapedType.getElementType().getIntOrFloatBitWidth());
4466 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4467 DiagnosedSilenceableFailure diag =
4468 emitSilenceableError()
4469 << "too few threads to map copy op to threads on the most minor "
4470 "dimension, given alignment and vector size constraints, try "
4471 "smaller tile size of mapping to more threads";
4472 diag.attachNote(target->getLoc()) << "target op";
4473 return diag;
4474 }
4475
4476 // OpBuilder only used to compute attributes.
4477 OpBuilder b(getContext());
4478 scf::SCFTilingResult tilingResult;
4479 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4480 /*rewriter=*/rewriter,
4481 /*state=*/state,
4482 /*transformOp=*/*this,
4483 /*target=*/target,
4484 /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4485 /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4486 /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4487 /*tilingResult=*/tilingResult);
4488 if (!diag.succeeded())
4489 return diag;
4490
4491 results.push_back(tilingResult.loops.front());
4492 for (auto *op : tilingResult.tiledOps)
4493 results.push_back(op);
4495}
4496
4497//===----------------------------------------------------------------------===//
4498// WinogradConv2DOp
4499//===----------------------------------------------------------------------===//
4500
4501DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4502 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4503 transform::ApplyToEachResultList &results,
4504 transform::TransformState &state) {
4505 rewriter.setInsertionPoint(target);
4506 FailureOr<Operation *> maybeTransformed = failure();
4507 bool supported = TypeSwitch<Operation *, bool>(target)
4508 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4509 maybeTransformed =
4510 winogradConv2D(rewriter, op, getFmr());
4511 return true;
4512 })
4513 .Default([&](Operation *op) { return false; });
4514
4515 if (!supported) {
4516 return emitSilenceableError()
4517 << "this operation is not supported to convert to Winograd Conv2D";
4518 }
4519
4520 if (failed(maybeTransformed)) {
4521 return emitSilenceableError() << "apply Winograd Conv2D failed";
4522 }
4523
4524 results.push_back(*maybeTransformed);
4526}
4527
4528DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4529 transform::TransformRewriter &rewriter, Operation *target,
4530 transform::ApplyToEachResultList &results,
4531 transform::TransformState &state) {
4532 rewriter.setInsertionPoint(target);
4533 FailureOr<Operation *> maybeTransformed = failure();
4534 bool supported =
4536 .Case([&](linalg::WinogradFilterTransformOp op) {
4537 maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4538 return true;
4539 })
4540 .Case([&](linalg::WinogradInputTransformOp op) {
4541 maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4542 return true;
4543 })
4544 .Case([&](linalg::WinogradOutputTransformOp op) {
4545 maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4546 return true;
4547 })
4548 .Default(false);
4549
4550 if (!supported) {
4551 DiagnosedSilenceableFailure diag =
4552 emitSilenceableError()
4553 << "this operation is not supported to decompose into other operations";
4554 diag.attachNote(target->getLoc()) << "target op";
4555 return diag;
4556 }
4557
4558 if (failed(maybeTransformed)) {
4559 DiagnosedSilenceableFailure diag =
4560 emitSilenceableError() << "decompose Winograd operations failed";
4561 diag.attachNote(target->getLoc()) << "target op";
4562 return diag;
4563 }
4564
4565 results.push_back(*maybeTransformed);
4567}
4568
4569#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4570
4571#define GET_OP_CLASSES
4572#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, bool packedResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized. TODO: Replace this with a g...
ArrayAttr()
static bool mayBeRead(OpOperand &operand)
Return true if the operand may be read from by its owner.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static bool sameOrEquivalentIterArg(Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tileSizes, Type)
static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Base type for affine expression.
Definition AffineExpr.h:68
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition Builders.cpp:310
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:143
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:161
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:318
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getOpResult(unsigned idx)
Definition Operation.h:446
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:559
void setOperand(unsigned idx, Value value)
Definition Operation.h:376
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:585
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:288
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
result_range getOpResults()
Definition Operation.h:445
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation, bool PoisonPaddingOk=false)
Patterns to bubble up or down data layout ops across other operations.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
void populateExtractSliceSinkingPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to sink extract slice across other operations.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice + copy.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Canonicalization patterns relevant to apply after tiling patterns.
Definition Tiling.cpp:866
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition Transforms.h:489
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors and memref.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:217
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition Tiling.cpp:236
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition Tiling.cpp:106
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition Hoisting.cpp:89
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of linalg.pack/unpack ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition Hoisting.cpp:198
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition Split.cpp:68
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition Tiling.cpp:262
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:675
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract_slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:60
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:287
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition Transforms.h:521
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition Transforms.h:409
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition Transforms.h:456
LinalgPromotionOptions & setAlignment(unsigned align)
Definition Transforms.h:433
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition Transforms.h:440
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition Transforms.h:398
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition Transforms.h:420
LinalgPromotionOptions & setUseOriginalSubviewSize(bool originalSize)
Definition Transforms.h:427
LinalgPromotionOptions & setUseAlloca(bool use)
Definition Transforms.h:446
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition Transforms.h:466