MLIR 23.0.0git
LinalgTransformOps.cpp
Go to the documentation of this file.
1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
40#include "mlir/Support/LLVM.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/ScopeExit.h"
44#include "llvm/ADT/SmallPtrSet.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
49#include <type_traits>
50
51using namespace mlir;
52using namespace mlir::linalg;
53using namespace mlir::transform;
54
55#define DEBUG_TYPE "linalg-transforms"
56
57/// Attempts to apply the pattern specified as template argument to the given
58/// operation. The pattern is expected to have a `returningMatchAndRewrite`
59/// function that returns the "main" result or failure. Returns failure if the
60/// pattern failed to apply. Extra arguments are forwarded to the pattern
61/// constructor.
62template <typename PatternTy, typename... Args>
63static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
64 // Check if the given operation has the type expected by the pattern.
65 using OpTy = typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
68 if (!op)
69 return failure();
70
71 // Apply the pattern directly to the op.
72 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
73 // We want to discourage direct use of PatternRewriter in APIs but In this
74 // very specific case, an IRRewriter is not enough.
75 PatternRewriter rewriter(operation->getContext());
76 rewriter.setInsertionPoint(operation);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
78 if (failed(result))
79 return failure();
80 return cast<LinalgOp>(result->getOperation());
81}
82
83/// Assuming that `ofr` is an index attr or a param of index type
84/// or a transform dialect handle mapped to exactly one op
85/// with one index result, return that value.
87 transform::TransformState &state, TransformOpInterface transformOp,
89 for (OpFoldResult ofr : ofrs) {
90 if (auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
93 result.push_back(ofr);
94 continue;
95 }
96
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
99 ArrayRef<Attribute> params = state.getParams(transformValue);
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 << "requires exactly one parameter associated";
103 result.push_back(params[0]);
104 continue;
105 }
106
107 auto payloadOps = state.getPayloadOps(transformValue);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 << "handle must be mapped to exactly one payload op";
112 diag.attachNote(transformValue.getLoc())
113 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
114 return diag;
115 }
116
117 Operation *op = *payloadOps.begin();
118 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
120 transformOp.emitSilenceableError()
121 << "payload op must have exactly 1 index result";
122 diag.attachNote(op->getLoc())
123 << "has " << op->getNumResults() << " results";
124 return diag;
125 }
126 result.push_back(op->getResult(0));
127 }
128
130}
131
132// Given a list of params that are index attrs or a list of OpFoldResults
133// that are either index attrs or op handles, return a list of OpFoldResults
134// of index attrs or a list of OpFoldResults where all op handles are
135// replaced with the first (and only) OpResult of that payload op.
136// (There must be exactly one parameter associated with the AnyParamType or
137// one mapped payload op which must have exactly one index result.)
139 transform::TransformState &state, TransformOpInterface transformOp,
140 SmallVector<OpFoldResult> &result, Value packedHandle) {
141 if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
142 ArrayRef<Attribute> params = state.getParams(packedHandle);
143 for (auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 << "expected the parameter to be associated with an integer "
147 "attribute";
148 result.push_back(param);
149 }
151 }
152
153 for (Operation *op : state.getPayloadOps(packedHandle)) {
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 << "payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 << "has " << op->getNumResults() << " results";
160 return diag;
161 }
162 result.push_back(op->getResult(0));
163 }
164
166}
167
168/// When possible, converts each `OpFoldResult` in `mixedResult` to
169/// an integer if the value can be statically inferred. If a result
170/// is a `Value` then it must be either a `ParamType` or a handle
171/// to an a constant like op.
173 TransformState &state, TransformOpInterface &transformOp,
174 ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
175 for (OpFoldResult paramOrHandle : mixedResults) {
176 if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
178 continue;
179 }
180 if (isa<TransformParamTypeInterface>(
181 cast<Value>(paramOrHandle).getType())) {
182 ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
183 if (params.size() != 1)
184 return transformOp.emitSilenceableError() << "expected a single param";
185 reified.push_back(
186 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
187 continue;
188 }
189
190 Value handle = cast<Value>(paramOrHandle);
191 if (!isa<TransformHandleTypeInterface>(handle.getType()))
192 return transformOp.emitSilenceableError() << "unexpected value handle";
193 auto payload = state.getPayloadOps(handle);
194 if (!llvm::hasSingleElement(payload))
195 return transformOp.emitSilenceableError()
196 << "requires param or handle that is mapped to 1 payload op";
197
198 Operation *paramOrHandlePayloadOp = *payload.begin();
199 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
200 !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
201 return transformOp.emitSilenceableError()
202 << "requires param or handle to be result of op with 1 index "
203 "result";
204 }
205
206 IntegerAttr attr;
207 if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
208 return transformOp.emitSilenceableError()
209 << "requires param or handle to be the result of a constant like "
210 "op";
211
212 reified.push_back(attr.getInt());
213 }
215}
216
217//===----------------------------------------------------------------------===//
218// Apply...PatternsOp
219//===----------------------------------------------------------------------===//
220
221void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
222 RewritePatternSet &patterns) {
224}
225
226void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
227 RewritePatternSet &patterns) {
229}
230
231void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
232 RewritePatternSet &patterns) {
234}
235
236void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
237 RewritePatternSet &patterns) {
240}
241
242void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
243 RewritePatternSet &patterns) {
245 options.rankReductionStrategy =
248}
249
250void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
251 RewritePatternSet &patterns) {
253}
254
255void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
256 RewritePatternSet &patterns) {
258}
259
260void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
261 RewritePatternSet &patterns) {
263}
264
265void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
266 RewritePatternSet &patterns) {
268}
269
270void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
271 RewritePatternSet &patterns) {
273}
274
275void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
276 RewritePatternSet &patterns) {
277 linalg::ControlPropagationFn defaultControlFn = [](OpOperand *operand) {
278 return true;
279 };
280 linalg::populateDataLayoutPropagationPatterns(patterns, defaultControlFn,
281 getPoisonPadding());
282}
283
284void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
285 RewritePatternSet &patterns) {
286 linalg::ControlPropagationFn defaultControlFn =
287 [](OpOperand *opOperand) -> bool {
288 Operation *producer = opOperand->get().getDefiningOp();
289 Operation *consumer = opOperand->getOwner();
290 return consumer->getBlock() == producer->getBlock();
291 };
292 linalg::populateExtractSliceSinkingPatterns(patterns, defaultControlFn);
293}
294
295//===----------------------------------------------------------------------===//
296// BufferizeToAllocationOp
297//===----------------------------------------------------------------------===//
298
299namespace {
300class NewOpsListener : public RewriterBase::ForwardingListener {
301public:
303
304 SmallVector<Operation *> getNewOps() const {
305 return SmallVector<Operation *>(newOps.begin(), newOps.end());
306 }
307
308private:
309 void notifyOperationInserted(Operation *op,
310 OpBuilder::InsertPoint previous) override {
311 ForwardingListener::notifyOperationInserted(op, previous);
312 // We only care about newly created ops.
313 if (previous.isSet())
314 return;
315 auto inserted = newOps.insert(op);
316 (void)inserted;
317 assert(inserted.second && "expected newly created op");
318 }
319
320 void notifyOperationErased(Operation *op) override {
321 ForwardingListener::notifyOperationErased(op);
322 op->walk([&](Operation *op) { newOps.erase(op); });
323 }
324
326};
327} // namespace
328
329DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
332 // Attach listener to keep track of newly created ops.
333 OpBuilder::Listener *previousListener = rewriter.getListener();
334 llvm::scope_exit resetListener(
335 [&]() { rewriter.setListener(previousListener); });
336 NewOpsListener newOpsListener(previousListener);
337 rewriter.setListener(&newOpsListener);
338
340 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
341 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
342 MaterializeInDestination;
343 } else if (getMemcpyOp() == "memref.copy") {
344 options.memcpyOp =
346 } else if (getMemcpyOp() == "linalg.copy") {
347 options.memcpyOp =
349 } else {
350 llvm_unreachable("invalid memcpy op");
351 }
352 if (getAllocOp() == "memref.alloc") {
353 options.allocOp =
355 } else if (getAllocOp() == "memref.alloca") {
356 options.allocOp =
358 } else {
359 llvm_unreachable("invalid alloc op");
360 }
361 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
362 options.emitDealloc = getEmitDealloc();
363
364 // Bufferize ops.
365 Attribute memorySpace =
366 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
367 SmallVector<Value> allocatedBuffers;
368 for (Operation *op : state.getPayloadOps(getTarget())) {
369 Value buffer =
370 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
371 if (!buffer) {
372 DiagnosedSilenceableFailure diag = emitSilenceableError()
373 << "failed to bufferize operation";
374 diag.attachNote(op->getLoc()) << "target payload op";
375 return diag;
376 }
377 allocatedBuffers.push_back(buffer);
378 }
379
380 // Set results.
381 results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
382 results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
384}
385
386void transform::BufferizeToAllocationOp::getEffects(
388 if (getBufferizeDestinationOnly()) {
389 // The destination is replaced with a newly allocated buffer, but the op
390 // itself remains in place.
391 onlyReadsHandle(getTargetMutable(), effects);
392 } else {
393 consumesHandle(getTargetMutable(), effects);
394 }
395 producesHandle(getOperation()->getOpResults(), effects);
396 modifiesPayload(effects);
397}
398
399LogicalResult transform::BufferizeToAllocationOp::verify() {
400 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
401 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
402 return emitOpError() << "unsupported memcpy op";
403 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
404 return emitOpError() << "unsupported alloc op";
405 return success();
406}
407
408//===----------------------------------------------------------------------===//
409// PromoteTensorOp
410//===----------------------------------------------------------------------===//
411
412/// Return true if the operand may be read from by its owner. This is currently
413/// very conservative and only looks inside linalg operations to prevent
414/// unintentional data loss.
415static bool mayBeRead(OpOperand &operand) {
416 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
417
418 // Be conservative about ops we cannot analyze deeper.
419 if (!linalgOp)
420 return true;
421
422 // Look inside linalg ops.
423 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
424 return !blockArgument.use_empty();
425}
426
427/// Return true if the value may be read through any of its uses.
428static bool mayBeRead(Value value) {
429 // If the value has a reference semantics, it
430 // may be read through any alias...
431 if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
432 return true;
433 return llvm::any_of(value.getUses(),
434 static_cast<bool (&)(OpOperand &)>(mayBeRead));
435}
436
438transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
441 SmallVector<Value> promoted;
442 for (Value tensor : state.getPayloadValues(getTensor())) {
443 auto type = dyn_cast<RankedTensorType>(tensor.getType());
444 if (!type) {
445 return emitSilenceableError() << "non-tensor type: " << tensor;
446 }
447
448 Operation *definingOp = tensor.getDefiningOp();
449 if (definingOp)
450 rewriter.setInsertionPointAfter(definingOp);
451 else
452 rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
453
454 // Check this before we emit operations using this value.
455 bool needsMaterialization = mayBeRead(tensor);
456
457 SmallVector<Value> dynamicDims;
459 for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
460 if (!ShapedType::isDynamic(dim))
461 continue;
462 Value cst =
463 arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
464 auto dimOp =
465 tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
466 preservedOps.insert(dimOp);
467 dynamicDims.push_back(dimOp);
468 }
469 auto allocation = bufferization::AllocTensorOp::create(
470 rewriter, tensor.getLoc(), type, dynamicDims);
471 // Set memory space if provided.
472 if (getMemorySpaceAttr())
473 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
474 Value allocated = allocation;
475
476 // Only insert a materialization (typically bufferizes to a copy) when the
477 // value may be read from.
478 if (needsMaterialization) {
479 auto copy = bufferization::MaterializeInDestinationOp::create(
480 rewriter, tensor.getLoc(), tensor, allocated);
481 preservedOps.insert(copy);
482 promoted.push_back(copy.getResult());
483 } else {
484 promoted.push_back(allocated);
485 }
486 rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
487 }
488 results.setValues(cast<OpResult>(getPromoted()), promoted);
490}
491
492void transform::PromoteTensorOp::getEffects(
494 transform::onlyReadsHandle(getTensorMutable(), effects);
495 transform::producesHandle(getOperation()->getOpResults(), effects);
497}
498
499//===----------------------------------------------------------------------===//
500// DecomposeOp
501//===----------------------------------------------------------------------===//
502
504transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
505 LinalgOp target,
508#define DOWNSCALE(trans) \
509 { \
510 FailureOr<LinalgOp> res = tryApply<trans>(target); \
511 if (succeeded(res)) { \
512 results.push_back(*res); \
513 return DiagnosedSilenceableFailure::success(); \
514 } \
515 }
516
517#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
518#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
519
520 DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
521 DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
522 DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
523 DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
524 DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
525 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
526 DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
527 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
528 DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
531#undef DOWNSCALE_NORMAL
532#undef DOWNSCALE_CALL
533#undef DOWNSCALE
534 return emitDefaultSilenceableFailure(target);
535}
536
537//===----------------------------------------------------------------------===//
538// DecomposeInterfaceOp
539//===----------------------------------------------------------------------===//
540
541// Decompose the target operation if it implements the AggregatedOpInterface.
542// Push the decomposed operations (the ones that replaces the values produced by
543// \p target) in the `results`.
544DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
548 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
549 if (!decomposableOp) {
551 "payload is not a decomposable op"));
552 return emitDefaultSilenceableFailure(target);
553 }
554
555 FailureOr<SmallVector<Value>> maybeNewResults =
556 decomposableOp.decomposeOperation(rewriter);
557 if (failed(maybeNewResults))
558 return emitDefaultSilenceableFailure(target);
559
560 rewriter.replaceOp(decomposableOp, *maybeNewResults);
561 for (Value val : *maybeNewResults) {
562 Operation *definition = val.getDefiningOp();
563 if (definition)
564 results.push_back(definition);
565 }
567}
568
569//===----------------------------------------------------------------------===//
570// EliminateLinalgOpAnchoredEmptyTensorsOp
571//===----------------------------------------------------------------------===//
572
573void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
575 onlyReadsHandle(getTargetMutable(), effects);
576 modifiesPayload(effects);
577}
578
580transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
581 transform::TransformRewriter &rewriter, TransformResults &transformResults,
582 TransformState &state) {
584 options.allowReturnAllocsFromLoops = true;
585
586 for (Operation *target : state.getPayloadOps(getTarget())) {
588 if (failed(analyzeOp(target, state)))
589 return mlir::emitSilenceableFailure(target->getLoc())
590 << "failed to analyze op";
592 rewriter, target, state)))
593 return mlir::emitSilenceableFailure(target->getLoc())
594 << "failed to eliminate LinalgOp anchored tensor.empty ops";
595 }
597}
598
599//===----------------------------------------------------------------------===//
600// FuseOp
601//===----------------------------------------------------------------------===//
602
603void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
604 TypeRange loopTypes, Value target,
605 ArrayRef<int64_t> staticTileSizes,
606 ArrayRef<int64_t> staticTileInterchange,
607 bool applyCleanup, bool useForall) {
608 return build(
609 builder, result, loopTypes,
610 /*target=*/target,
611 /*mixedTileSizes=*/
612 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
613 /*mixedTileInterchange=*/
614 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
615 applyCleanup, useForall);
616}
617
618void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
619 Value target, ArrayRef<int64_t> staticTileSizes,
620 ArrayRef<int64_t> staticTileInterchange,
621 bool applyCleanup, bool useForall) {
622 return build(
623 builder, result,
624 /*target=*/target,
625 /*mixedTileSizes=*/
626 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
627 /*mixedTileInterchange=*/
628 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
629 applyCleanup, useForall);
630}
631
632void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
634 ArrayRef<OpFoldResult> mixedTileSizes,
635 ArrayRef<OpFoldResult> mixedTileInterchange,
636 bool applyCleanup, bool useForall) {
637 // Loop types are automaticaly splat by the callee, setting up one is
638 // enough.
639 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
640 build(builder, result, loopTypes, target, mixedTileSizes,
641 mixedTileInterchange, applyCleanup, useForall);
642}
643
644void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
645 TypeRange loopTypes, Value target,
646 ArrayRef<OpFoldResult> mixedTileSizes,
647 ArrayRef<OpFoldResult> mixedTileInterchange,
648 bool applyCleanup, bool useForall) {
649 SmallVector<int64_t> staticTileSizes;
650 SmallVector<Value> dynamicTileSizes;
651 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
652 SmallVector<int64_t> staticTileInterchange;
653 SmallVector<Value> dynamicTileInterchange;
654 dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
655 staticTileInterchange);
656 // Call the default builder which sets up the proper operands segment sizes
657 // attributes for multiple variadic operands. In the absence of this,
658 // horrible bugs ensue.
659 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
660 auto staticTileInterchangeAttr =
661 builder.getDenseI64ArrayAttr(staticTileInterchange);
662 unsigned numExpectedLoops =
663 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
664 SmallVector<Type> resultTypes;
665 resultTypes.reserve(numExpectedLoops);
666 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
667 "expected one loop type or as many as loops");
668 if (loopTypes.size() == 1)
669 resultTypes.append(numExpectedLoops, loopTypes[0]);
670 else
671 llvm::append_range(resultTypes, loopTypes);
672 build(builder, result, /*transformed=*/target.getType(),
673 /*loops=*/resultTypes,
674 /*target=*/target,
675 /*tile_sizes=*/dynamicTileSizes,
676 /*tile_interchange=*/dynamicTileInterchange,
677 /*static_tile_sizes=*/staticTileSizesAttr,
678 /*static_tile_interchange=*/staticTileInterchangeAttr,
679 /*apply_cleanup=*/applyCleanup,
680 /*use_forall=*/useForall);
681}
682
683/// Apply a tiling transformation to all payload ops and store both the
684/// tiled operation as well as the created tile loops.
685template <typename Range>
686static LogicalResult applyTilingToAll(
687 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
688 unsigned numLoops, transform::TransformResults &transformResults,
689 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
690 applyFn) {
691 SmallVector<Operation *> tiledLinalgOps;
692 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
693
694 for (Operation *target : payloadOps) {
695 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
696 if (!tilingInterfaceOp)
697 return transformOp->emitError("only TilingInterface ops are supported");
698
699 rewriter.setInsertionPoint(target);
700 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
701 applyFn(tilingInterfaceOp);
702 if (failed(tiledResults))
703 return failure();
704
705 // Perform the replacement of tiled and fused values.
706 SmallVector<Operation *> opsToReplace{target};
707 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
708 for (Operation *toReplace : opsToReplace) {
709 for (OpResult res : toReplace->getResults())
710 if (auto replacement = tiledResults->replacements.lookup(res))
711 rewriter.replaceAllUsesWith(res, replacement);
712 if (toReplace->use_empty()) {
713 rewriter.eraseOp(toReplace);
714 }
715 }
716
717 // Report back the relevant handles to the transform op.
718 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
719 assert(tiledResults->loops.size() == numLoops &&
720 "Mismatched number of loops, tile and fuse transform should have "
721 "failed");
722 for (unsigned int i = 0; i < numLoops; ++i)
723 loopOps[i].push_back(tiledResults->loops[i]);
724 }
725
726 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
727 for (unsigned int i = 0; i < numLoops; ++i)
728 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
729
730 return success();
731}
732
734transform::FuseOp::apply(transform::TransformRewriter &rewriter,
735 mlir::transform::TransformResults &transformResults,
737 auto transformOp = cast<TransformOpInterface>(getOperation());
738
739 SmallVector<int64_t> tileSizes;
741 state, transformOp, getMixedTileSizes(), tileSizes);
742 if (!status.succeeded())
743 return status;
744 SmallVector<int64_t> tileInterchange;
746 state, transformOp, getMixedTileInterchange(), tileInterchange);
747 if (!status.succeeded())
748 return status;
749
750 scf::SCFTilingOptions tilingOptions;
751 tilingOptions.interchangeVector = tileInterchange;
752 bool useForall = getUseForall();
753 tilingOptions.setLoopType(useForall
754 ? scf::SCFTilingOptions::LoopType::ForallOp
755 : scf::SCFTilingOptions::LoopType::ForOp);
756 SmallVector<OpFoldResult> tileSizesOfr =
757 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
758 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
759 scf::SCFTileAndFuseOptions tileAndFuseOptions;
760 tileAndFuseOptions.tilingOptions = tilingOptions;
761
762 if (getApplyCleanup()) {
763 MLIRContext *context = rewriter.getContext();
764 RewritePatternSet patterns(context);
765 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
768 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
769 }
770
771 size_t numLoops =
772 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
773 LogicalResult result = applyTilingToAll(
774 rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
775 transformResults,
776 [&](TilingInterface tilingInterfaceOp)
777 -> FailureOr<scf::SCFTileAndFuseResult> {
778 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
779 tileAndFuseOptions);
780 });
783}
784
785LogicalResult transform::FuseOp::verify() {
786 auto iterspace_rank = getStaticTileSizes().size();
787 ArrayRef<int64_t> permutation = getStaticTileInterchange();
788 if (permutation.size() > iterspace_rank)
789 return emitOpError()
790 << "interchange length exceeds iteration space dimensions ("
791 << iterspace_rank << "), found " << getTileInterchange();
792 SmallVector<bool> seen(iterspace_rank, false);
793 for (int64_t v : permutation) {
794 if (!ShapedType::isDynamic(v)) {
795 if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))
796 return emitOpError() << "expects interchange values to be in range [0, "
797 << iterspace_rank << "), found: " << v;
798 if (seen[v])
799 return emitOpError() << "found duplicate interchange value: " << v;
800 seen[v] = true;
801 }
802 }
803
804 ArrayRef<int64_t> sizes = getStaticTileSizes();
805 size_t numExpectedLoops =
806 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
807 if (numExpectedLoops != getNumResults() - 1)
808 return emitOpError() << "expects " << numExpectedLoops << " loop results";
809
810 return success();
811}
812
813SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
814 return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
815}
816
817SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
818 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
819 getContext());
820}
821
822void transform::FuseOp::getEffects(
824 consumesHandle(getTargetMutable(), effects);
825 onlyReadsHandle(getTileSizesMutable(), effects);
826 onlyReadsHandle(getTileInterchangeMutable(), effects);
827 producesHandle(getOperation()->getOpResults(), effects);
828 modifiesPayload(effects);
829}
830
831//===----------------------------------------------------------------------===//
832// FuseIntoContainingOp
833//===----------------------------------------------------------------------===//
834
835void transform::FuseIntoContainingOp::build(OpBuilder &builder,
837 Value producerOp,
838 Value containingOp) {
839 result.addOperands({producerOp, containingOp});
840 auto resultType = transform::AnyOpType::get(builder.getContext());
841 result.addTypes({resultType, resultType});
842}
843
844/// Add new operands to the forall op for users of the producerOp
845/// that are dominated by the containing scf.forall op.
847 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
848 Operation *containingOp, TilingResult &tileAndFuseResult,
849 int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
851
852 // Count number of users not including the containing op
853 SetVector<Operation *> dominatedUsers;
854 DominanceInfo domInfo(containingOp);
855 for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
856 if (!containingOp->isAncestor(user) &&
857 (domInfo.dominates(containingOp, user))) {
858 dominatedUsers.insert(user);
859 }
860 }
861 if (dominatedUsers.empty())
862 return nullptr;
863
864 // Create new scf.forall op
865 auto forallOp = cast<scf::ForallOp>(containingOp);
866 OpBuilder::InsertionGuard g(rewriter);
867 rewriter.setInsertionPoint(forallOp);
868
869 // Get new output
870 Location loc = forallOp.getLoc();
871 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
872 if (!genericOp)
873 return nullptr;
874 SmallVector<Value> outputs = genericOp.getOutputs();
875 SmallVector<Value> newOuts(forallOp.getOutputs());
876 newOuts.push_back(outputs[resultNumber]);
877
878 // Create new scf.forall op
879 auto newforallOp = scf::ForallOp::create(
880 rewriter, loc, forallOp.getMixedLowerBound(),
881 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
882 forallOp.getMapping());
883 rewriter.eraseBlock(newforallOp.getBody());
884 newforallOp.getRegion().takeBody(forallOp.getRegion());
885
886 // Add additional block argument for new value being returned
887 // and replaces all uses of the new output with corresponding bbArg
888 // inside the scf.forall to enable fusion into this new scf.forall.
889 newforallOp.getBody()->addArgument(newOuts.back().getType(),
890 newOuts.back().getLoc());
891 auto bbArgs = newforallOp.getBody()->getArguments();
892 rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
893 [&](OpOperand &use) {
894 Operation *op = use.getOwner();
895 return newforallOp->isProperAncestor(op);
896 });
897
898 // Fix terminator
899 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
900 SmallVector<Operation *> yieldingOps = llvm::map_to_vector<4>(
901 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
902 Operation *firstYieldOp = yieldingOps.front();
903 rewriter.setInsertionPoint(firstYieldOp);
904 Value src = tileAndFuseResult.tiledValues[0];
905 Value dst = newforallOp.getRegionIterArgs().back();
906 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
907 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
908 dst, offsets, sizes, strides);
909
910 for (auto result : llvm::enumerate(forallOp.getResults())) {
911 rewriter.replaceAllUsesWith(result.value(),
912 newforallOp->getResult(result.index()));
913 }
914 rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
915 newforallOp->getResults().back(),
916 [&](OpOperand &use) {
917 Operation *user = use.getOwner();
918 return dominatedUsers.contains(user);
919 });
920 return newforallOp;
921}
922
923/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
924/// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
925/// outer loop. To determine the second condition, this function iterates
926/// using a worklist over the enclosing loops, trying to find 'src' in any of
927/// the parent loop's iter args.
928static bool sameOrEquivalentIterArg(Value src, Value dst) {
929 // Stack like vector containing possible iterArgs candidates. The first one
930 // is dst, and we will transverse the IR from there.
931 SmallVector<Value> destWorklist;
932 destWorklist.push_back(dst);
933
934 while (!destWorklist.empty()) {
935 Value currentDst = destWorklist.pop_back_val();
936
937 // We have found the same operand in some iter arg in the loop structure,
938 // so src and dst are equivalent.
939 if (src == currentDst)
940 return true;
941
942 // The operands are not equivalent, look for enclosing loops over
943 // currentDst.
944 auto bbArg = dyn_cast<BlockArgument>(currentDst);
945 if (!bbArg)
946 continue;
947
948 Block *parentBlock = bbArg.getOwner();
949 assert(parentBlock && "unlinked block argument");
950
951 Operation *parentOp = parentBlock->getParentOp();
952 assert(parentOp && "expected block argument with parent operation");
953
954 // Check if parent is loop-like. If it's not, do not add it to the worklist.
955 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
956 if (!parentLoop)
957 continue;
958
959 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
960 // No need to check for null as innerIterArg is tied to parentLoop.
961 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
962 Value loopBlockArgument =
963 parentLoop->getOperand(operand->getOperandNumber());
964 destWorklist.push_back(loopBlockArgument);
965 }
966 }
967
968 return false;
969}
970
971/// Find the first "extract" user of `producerOp` and tile it right before its
972/// use. The tiled op is fused under the `containingOp`.
973/// Return this fused op on success or nullptr if anything fails.
974/// If tiled op has uses that are dominated by `containingOp`, return
975/// a new `containingOp` with results of the fused op appended to
976/// results of the `containingOp` or nullptr if there are no dominated uses.
977static std::tuple<SmallVector<Operation *>, Operation *>
979 Operation *producerOp, Operation *containingOp) {
980 LDBG() << "Try to fuse a direct extract use";
981 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
982 if (!tileableProducer) {
983 diag.attachNote(producerOp->getLoc())
984 << "producer is not a TileableInterface: " << *producerOp;
985 return {};
986 }
987
988 // Search the producer slices accessed within the containing operation.
989 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
990 // evolve into an interface.
991 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
992 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
993 return sliceOp && containingOp->isProperAncestor(sliceOp);
994 });
995
996 // Find a fusion opportunity.
997 if (it == tileableProducer->getUsers().end()) {
998 diag.attachNote(tileableProducer->getLoc())
999 << "could not find fusion opportunity for: " << *tileableProducer;
1000 return {};
1001 }
1002 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
1003
1004 // Try to fuse the producer in-place.
1005 OpBuilder::InsertionGuard guard(rewriter);
1006 rewriter.setInsertionPoint(sliceOpToTile);
1007
1008 // Clone the producer inside the consumer and try to update the producer init
1009 // operands using the loop bbArgs if applicable. More precisely, if the bbArg
1010 // of the container loop points to a value that it is used by the consumer op,
1011 // then, instead of using such value on the consumer, use the value coming
1012 // from the bbArg instead. This allows to reuse the output tensor (instead of
1013 // creating a new one) of the container when both producer and container write
1014 // to the same output.
1015 if (LoopLikeOpInterface containerLoop =
1016 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
1017 Operation *clone = rewriter.clone(*producerOp);
1018 rewriter.modifyOpInPlace(clone, [&]() {
1019 // Iterate over the outputs of the producer and over the loop bbArgs and
1020 // check if any bbArg points to the same value as the producer output. In
1021 // such case, make the producer output point to the bbArg directly.
1022 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
1023 if (!dpsInterface)
1024 return;
1025
1026 for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1027 Value producerOperand =
1028 clone->getOperand(initOperandPtr.getOperandNumber());
1029 for (BlockArgument containerIterArg :
1030 containerLoop.getRegionIterArgs()) {
1031 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1032 Value consumerOperand =
1033 containerLoop->getOperand(bbArg->getOperandNumber());
1034 // The producer has the same init as the loop bbArg, use it.
1035 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
1036 initOperandPtr.set(containerIterArg);
1037 }
1038 }
1039 }
1040 });
1041
1042 tileableProducer = dyn_cast<TilingInterface>(clone);
1043 }
1044
1045 // Tile the producer.
1046 int64_t resultNumber =
1047 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1048 LDBG() << "resultNumber: " << resultNumber;
1049
1050 SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
1051 SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
1052
1053 FailureOr<TilingResult> tileAndFuseResult =
1054 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1055 sizes);
1056
1057 if (failed(tileAndFuseResult)) {
1058 diag.attachNote(tileableProducer->getLoc())
1059 << "failed to tile producer op: " << *tileableProducer;
1060 return {};
1061 }
1062
1063#ifndef NDEBUG
1064 for (auto *tiledOp : tileAndFuseResult->tiledOps) {
1065 LDBG() << "tiledProducer: " << *tiledOp;
1066 }
1067#endif
1068
1069 // Replace the extract op.
1070 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1071 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1072 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1073 if (failed(maybeRankReduced)) {
1074 diag.attachNote(producerOp->getLoc())
1075 << "shape types don't match (missing canonicalization?):\nTiledOp: "
1076 << tileAndFuseResult->tiledValues[0]
1077 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
1078 return {};
1079 }
1080 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1081
1082 // Add new outputs to containing op, if required
1083 Operation *newContainingOp = replaceForAllWithNewSignature(
1084 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
1085 resultNumber, offsets, sizes);
1086
1087 // Cleanup clone.
1088 if (isa<LoopLikeOpInterface>(containingOp))
1089 rewriter.eraseOp(tileableProducer);
1090
1091 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1092}
1093
1094/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
1095/// it is exactly the `containingOp`, otherwise bail.
1096/// Then, find the first "extract" user of the tied block argument and tile it
1097/// right before its "extract" use. The tiled op is fused under the
1098/// `containingOp`.
1099/// Return this fused op on success or nullptr if anything fails.
1102 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
1103 Operation *containingOp) {
1104 LDBG() << "Try to fuse an extract use through block argument";
1105
1106 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1107 if (!tileableProducer) {
1108 diag.attachNote(producerOp->getLoc())
1109 << "producer is not a TileableInterface: " << *producerOp;
1110 return {};
1111 }
1112
1113 // Search the first use by a "scf::ForallOp" user.
1114 scf::ForallOp forallOp;
1115 auto itProducerUses =
1116 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
1117 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1118 return forallOp;
1119 });
1120 // If it's not from the containing op, return.
1121 if (!forallOp || forallOp != containingOp) {
1122 diag.attachNote(tileableProducer->getLoc())
1123 << "could not find a use by the containing op: " << *tileableProducer;
1124 return {};
1125 }
1126
1127 // Search the producer slices accessed within the containing
1128 // operation.
1129 // TODO: Generalize to more extract/insert/parallel_insert triples.
1130 // Maybe evolve into an interface.
1131 OpOperand *pUse = &(*itProducerUses);
1132 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
1133
1134 // Search the producer slices accessed within the containing operation.
1135 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
1136 // evolve into an interface.
1137 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
1138 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1139 return sliceOp && containingOp->isProperAncestor(sliceOp);
1140 });
1141
1142 // Find a fusion opportunity.
1143 if (itBBArgUsers == bbArg.getUsers().end()) {
1144 diag.attachNote(containingOp->getLoc())
1145 << "could not find fusion opportunity for bbArg: " << bbArg;
1146 return {};
1147 }
1148 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1149
1150 // Try to fuse the producer in-place.
1151 OpBuilder::InsertionGuard guard(rewriter);
1152 rewriter.setInsertionPoint(sliceOpToTile);
1153
1154 // Replace the use in the tileableProducer before tiling: clone, replace and
1155 // then tile.
1156 int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
1157 LDBG() << "resultNumber: " << resultNumber;
1158
1159 // Gather destination tensors.
1160 SmallVector<Value> destinationTensors;
1162 rewriter, tileableProducer->getLoc(), tileableProducer,
1163 destinationTensors))) {
1164 diag.attachNote(tileableProducer->getLoc())
1165 << "failed to get destination tensors for: " << *tileableProducer;
1166 return {};
1167 }
1168
1169 IRMapping bvm;
1170 bvm.map(destinationTensors[resultNumber], bbArg);
1171 auto tileableProducerClone =
1172 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
1173 llvm::scope_exit scopeGuard(
1174 [&]() { rewriter.eraseOp(tileableProducerClone); });
1175
1176 // Tile the producer.
1177 FailureOr<TilingResult> tileAndFuseResult =
1178 tileableProducerClone.generateResultTileValue(
1179 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1180 sliceOpToTile.getMixedSizes());
1181 if (failed(tileAndFuseResult)) {
1182 diag.attachNote(tileableProducer->getLoc())
1183 << "failed to tile producer op: " << *tileableProducer;
1184 return {};
1185 }
1186
1187 // Replace the extract op.
1188 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1189 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1190 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1191 assert(succeeded(maybeRankReduced) && "unexpected shape");
1192 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1193
1194 // Replace the use in containingOp.
1195 rewriter.modifyOpInPlace(containingOp, [&]() {
1196 containingOp->setOperand(pUse->getOperandNumber(),
1197 destinationTensors.front());
1198 });
1199
1200 return tileAndFuseResult->tiledOps;
1201}
1202
1204 Operation *producerOp,
1205 Operation *containingOp) {
1206 LDBG() << "Try to fuse an use by cloning";
1207
1208 // Gather all uses inside the containing op.
1210 for (OpResult result : producerOp->getOpResults()) {
1211 for (OpOperand &use : result.getUses()) {
1212 if (containingOp->isProperAncestor(use.getOwner())) {
1213 uses.push_back(&use);
1214 continue;
1215 }
1216 // Cannot clone and fuse if the use is by the containing op itself: fail
1217 // immediately.
1218 if (containingOp == use.getOwner()) {
1219 diag.attachNote(producerOp->getLoc())
1220 << "producer op use by containing op cannot be fused by cloning";
1221 return nullptr;
1222 }
1223 }
1224 }
1225
1226 // Check for a non-empty list of fusion opportunities.
1227 if (uses.empty()) {
1228 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1229 return nullptr;
1230 }
1231
1232 // Clone and fuse inside the containing op.
1233 Operation *fusedOp = nullptr;
1234 OpOperand *use = uses.front();
1235 // Parallel insert slice is not a valid clone destination.
1236 // TODO: Generalize to other type of ops.
1237 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1238 "Parallel insert slice is not a valid clone destination");
1239 unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1240 LDBG() << "resultNumber: " << resultNumber;
1241
1242 OpBuilder::InsertionGuard guard(rewriter);
1243 rewriter.setInsertionPoint(use->getOwner());
1244 fusedOp = rewriter.clone(*producerOp);
1245 rewriter.modifyOpInPlace(
1246 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1247
1248 return fusedOp;
1249}
1250
1251bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1252 // Allow repeated handles since we are fusing everything anyway.
1253 return true;
1254}
1255
1257transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1260 SmallVector<Operation *> fusedOps;
1261 auto producerOps = state.getPayloadOps(getProducerOp());
1262 auto containingOps = state.getPayloadOps(getContainingOp());
1263 if (!llvm::hasSingleElement(containingOps)) {
1264 return emitDefiniteFailure()
1265 << "requires exactly one containing_op handle (got "
1266 << llvm::range_size(containingOps) << ")";
1267 }
1268 Operation *containingOp = *containingOps.begin();
1269
1270 // If nothing to fuse, propagate success.
1271 if (std::empty(producerOps)) {
1272 results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1273 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1275 }
1276
1277 // Helper function to find the next producer that should be fused. Take any
1278 // producer that has a use inside the containing op.
1279 SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1280 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1281 for (const auto &it : enumerate(remainingProducers)) {
1282 Operation *producerOp = it.value();
1283 // The containing op may be a user of producerOp: use isAncestor.
1284 int64_t numUsesInContainingOp =
1285 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1286 return containingOp->isAncestor(op);
1287 });
1288 // TODO: When resolving the TODO below (no duplicate ops), take an op
1289 // that has no use among the remaining producers. This is a topological
1290 // sorting.
1291 if (numUsesInContainingOp > 0) {
1292 if (numUsesInContainingOp == 1)
1293 remainingProducers.erase(remainingProducers.begin() + it.index());
1294 return producerOp;
1295 }
1296 }
1297 return failure();
1298 };
1299
1300 while (!remainingProducers.empty()) {
1301 auto nextProducer = getNextProducer();
1302 if (failed(nextProducer)) {
1303 auto diag = mlir::emitSilenceableFailure(getLoc())
1304 << "could not find next producer to fuse into container";
1305 diag.attachNote(containingOp->getLoc()) << "containing op";
1306 return diag;
1307 }
1308
1309 Operation *producerOp = *nextProducer;
1310
1311 // Default diagnostic, to be complemented with more failure information.
1313 diag << "could not fuse " << *producerOp << " into " << *containingOp;
1314
1315 // TODO: If there are multiple uses of the producer in the containing op,
1316 // we currently tile/clone the op multiple times (once per use). In some
1317 // cases, we can tile/clone once and reuse the value for each use.
1318 // Futhermore, producers should then be traversed according to a
1319 // topological sorting.
1320 auto [tiledOps, newContainingOp] =
1321 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1322 if (!tiledOps.empty()) {
1323 LDBG() << "\nFused a direct extract use\n" << *containingOp;
1324 fusedOps.append(tiledOps);
1325 if (newContainingOp) {
1326 // Update handles associated with the containing op so we don't need to
1327 // invalidate them. This is a hack to support better composability
1328 // between tiling and fusion while a proper mechanism is being
1329 // investigated.
1330 //
1331 // DO NOT replicate this elsewhere unless you understand what you are
1332 // doing.
1333 LogicalResult replacementStatus =
1334 rewriter.notifyPayloadOperationReplaced(containingOp,
1335 newContainingOp);
1336 (void)replacementStatus;
1337 assert(succeeded(replacementStatus) &&
1338 "unable to update transform state mapping");
1339 rewriter.eraseOp(containingOp);
1340 containingOp = newContainingOp;
1341 }
1342 continue;
1343 }
1344
1345 SmallVector<Operation *> tiledContainingOpOperand =
1347 rewriter, diag, producerOp, containingOp);
1348 if (!tiledContainingOpOperand.empty()) {
1349 LDBG() << "\nFused an extract use through block argument\n"
1350 << *containingOp;
1351 fusedOps.append(tiledContainingOpOperand);
1352 continue;
1353 }
1354
1355 Operation *cloned =
1356 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1357 if (cloned) {
1358 LDBG() << "\nFused an use by cloning\n" << *containingOp;
1359 fusedOps.push_back(cloned);
1360 continue;
1361 }
1363 }
1364
1365 results.set(cast<OpResult>(getFusedOp()), fusedOps);
1366 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1368}
1369
1370void transform::FuseIntoContainingOp::getEffects(
1372 consumesHandle(getProducerOpMutable(), effects);
1373 onlyReadsHandle(getContainingOpMutable(), effects);
1374 producesHandle(getOperation()->getOpResults(), effects);
1375 modifiesPayload(effects);
1376}
1377
1378//===----------------------------------------------------------------------===//
1379// GeneralizeOp
1380//===----------------------------------------------------------------------===//
1381
1383transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1384 LinalgOp target,
1387 // Exit early if no transformation is needed.
1388 if (isa<GenericOp>(target)) {
1389 results.push_back(target);
1391 }
1392 rewriter.setInsertionPoint(target);
1393 FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1394 if (succeeded(generic)) {
1395 results.push_back(generic->getOperation());
1397 }
1398 return emitDefaultSilenceableFailure(target);
1399}
1400
1401//===----------------------------------------------------------------------===//
1402// SpecializeOp
1403//===----------------------------------------------------------------------===/
1404
1406transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1407 LinalgOp target,
1410 // Exit early if the operation is not a generic.
1411 if (!isa<GenericOp>(target)) {
1412 results.push_back(target);
1414 }
1415 rewriter.setInsertionPoint(target);
1416 FailureOr<LinalgOp> named =
1417 specializeGenericOp(rewriter, cast<GenericOp>(target));
1418 if (succeeded(named)) {
1419 results.push_back(named->getOperation());
1421 }
1422 return emitDefaultSilenceableFailure(target);
1423}
1424
1425//===----------------------------------------------------------------------===//
1426// InterchangeOp
1427//===----------------------------------------------------------------------===//
1428
1430transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1431 GenericOp target,
1434 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1435 // Exit early if no transformation is needed.
1436 if (interchangeVector.empty()) {
1437 results.push_back(target);
1439 }
1440
1441 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1442 if (interchangeVector.size() != numLoops) {
1443 return emitSilenceableError()
1444 << getIteratorInterchangeAttrName() << " has length ("
1445 << interchangeVector.size()
1446 << ") different from the number of loops in the target operation ("
1447 << numLoops << ")";
1448 }
1449 FailureOr<GenericOp> res = interchangeGenericOp(
1450 rewriter, target, SmallVector<unsigned>(interchangeVector));
1451 if (failed(res))
1452 return emitDefiniteFailure() << "failed to apply";
1453 results.push_back(res->getOperation());
1455}
1456
1457LogicalResult transform::InterchangeOp::verify() {
1458 ArrayRef<int64_t> permutation = getIteratorInterchange();
1459 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1460 if (!std::is_permutation(sequence.begin(), sequence.end(),
1461 permutation.begin(), permutation.end())) {
1462 return emitOpError()
1463 << "expects iterator_interchange to be a permutation, found "
1464 << getIteratorInterchange();
1465 }
1466 return success();
1467}
1468
1469//===----------------------------------------------------------------------===//
1470// LinalgCopyToMemrefOp
1471//===----------------------------------------------------------------------===//
1472
1473DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1474 transform::TransformRewriter &rewriter, Operation *targetOp,
1477
1478 // Check if the target can be converted.
1479 if (!isa<linalg::CopyOp>(targetOp)) {
1481 emitSilenceableError() << "only linalg.copy target ops are supported";
1482 diag.attachNote(targetOp->getLoc()) << "target op";
1483 return diag;
1484 }
1485
1486 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1487 if (!copyOp.hasPureBufferSemantics()) {
1489 emitSilenceableError()
1490 << "cannot transform a linalg.copy on tensors into a memref.copy";
1491 diag.attachNote(targetOp->getLoc()) << "target op";
1492 return diag;
1493 }
1494
1495 SmallVector<Value> inputs = copyOp.getInputs();
1496 SmallVector<Value> outputs = copyOp.getOutputs();
1497 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1498 assert(outputs.size() == 1 && "expected memref copy op with one output");
1499 Value input = inputs.front();
1500 Value output = outputs.front();
1501
1502 // linalg.copy supports different element types on source/dest whereas
1503 // memref.copy does not, so we must check that the source and dest types can
1504 // be handled by memref.copy and otherwise reject the transformation.
1505 if (!isa<ShapedType>(input.getType())) {
1507 emitSilenceableError()
1508 << "cannot transform a linalg.copy which input has no shape";
1509 diag.attachNote(targetOp->getLoc()) << "target op";
1510 return diag;
1511 }
1512
1513 // linalg.copy destination must be a shaped type.
1514 assert(isa<ShapedType>(output.getType()));
1515
1516 if (cast<ShapedType>(input.getType()).getElementType() !=
1517 cast<ShapedType>(output.getType()).getElementType()) {
1519 emitSilenceableError()
1520 << "cannot transform a linalg.copy with different source and "
1521 "destination element types ";
1522 diag.attachNote(targetOp->getLoc()) << "target op";
1523 return diag;
1524 }
1525
1526 // Target can be converted, do it.
1527 auto memrefCopyOp =
1528 rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1529
1530 results.push_back(memrefCopyOp);
1532}
1533
1534//===----------------------------------------------------------------------===//
1535// LowerPackOp
1536//===----------------------------------------------------------------------===//
1537
1538DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1539 transform::TransformRewriter &rewriter, linalg::PackOp target,
1540 transform::ApplyToEachResultList &transformResults,
1542 rewriter.setInsertionPoint(target);
1543 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1544 FailureOr<LowerPackResult> res =
1545 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1546 if (failed(res)) {
1547 return mlir::emitSilenceableFailure(target->getLoc())
1548 << "cannot lower to pad + expand + transpose";
1549 }
1550 transformResults.push_back(res->padOp);
1551 transformResults.push_back(res->expandShapeOp);
1552 transformResults.push_back(res->transposeOp);
1554}
1555
1556//===----------------------------------------------------------------------===//
1557// LowerUnPackOp
1558//===----------------------------------------------------------------------===//
1559
1560DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1561 transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1562 transform::ApplyToEachResultList &transformResults,
1564 rewriter.setInsertionPoint(target);
1565 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1566 FailureOr<LowerUnPackOpResult> res =
1567 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1568 if (failed(res)) {
1570 emitSilenceableError()
1571 << "cannot lower to transpose + collapse + extract";
1572 diag.attachNote(target->getLoc()) << "target payload op";
1573 return diag;
1574 }
1575 transformResults.push_back(res->emptyOp);
1576 transformResults.push_back(res->transposeOp);
1577 transformResults.push_back(res->collapseShapeOp);
1578 transformResults.push_back(res->extractSliceOp);
1579 transformResults.push_back(res->copyOp);
1581}
1582
1583//===---------------------------------------------------------------------===//
1584// MatchOp
1585//===---------------------------------------------------------------------===//
1586
1587void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1588 Value target, ArrayRef<StringRef> opNames) {
1589 result.addOperands(target);
1590 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1591 builder.getStrArrayAttr(opNames));
1592 result.addTypes(transform::AnyOpType::get(builder.getContext()));
1593}
1594
1595void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1596 TypeRange resultTypes, Value target,
1597 ArrayRef<StringRef> opNames) {
1598 result.addOperands(target);
1599 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1600 builder.getStrArrayAttr(opNames));
1601 result.addTypes(resultTypes);
1602}
1603
1605transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1608 llvm::StringSet<> strs;
1609 if (getOps().has_value())
1610 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1611
1612 auto payloadOps = state.getPayloadOps(getTarget());
1613 if (!llvm::hasSingleElement(payloadOps)) {
1614 return emitDefiniteFailure("requires exactly one target handle");
1615 }
1616
1618 bool incorrectNumOperandTypes = false;
1619 auto matchFun = [&](Operation *op) {
1620 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1621 return;
1622
1623 // Interfaces cannot be matched by name, just by ID.
1624 // So we specifically encode the interfaces we care about for this op.
1625 if (getInterface().has_value()) {
1626 auto iface = getInterface().value();
1627 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1628 !isa<LinalgOp>(op))
1629 return;
1630 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1631 !isa<TilingInterface>(op))
1632 return;
1633 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1634 !isa<LoopLikeOpInterface>(op))
1635 return;
1636 }
1637
1638 // Check if all specified attributes match.
1639 if (getOpAttrs().has_value()) {
1640 DictionaryAttr opAttrs = getOpAttrs().value();
1641 for (NamedAttribute attr : opAttrs) {
1642 if (attr.getName() == getInterfaceAttrName() ||
1643 attr.getName() == getOpsAttrName())
1644 continue;
1645 if (!op->hasAttr(attr.getName()))
1646 return;
1647 if (op->getAttr(attr.getName()) != attr.getValue())
1648 return;
1649 }
1650 }
1651
1652 if (getFilterResultType().has_value()) {
1653 Type t = getFilterResultType().value();
1654 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1655 return;
1656 }
1657
1658 if (getFilterOperandTypes().has_value()) {
1659 mlir::ArrayAttr types = getFilterOperandTypes().value();
1660 auto operandTypes = op->getOperandTypes();
1661
1662 if (types.size() == 1) {
1663 // All the operands must must be equal to the specified type
1664 auto typeattr =
1665 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1666 Type t = cast<::mlir::Type>(typeattr.getValue());
1667 if (!llvm::all_of(op->getOperandTypes(),
1668 [&](Type operandType) { return operandType == t; }))
1669 return;
1670 } else {
1671 // The operand types must match all the types in the list (in the same
1672 // order in with they are specified)
1673 if (types.size() != operandTypes.size()) {
1674 incorrectNumOperandTypes = true;
1675 return;
1676 }
1677
1678 for (auto [attr, operandType] :
1679 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1680 auto typeattr = cast<mlir::TypeAttr>(attr);
1681 Type type = cast<::mlir::Type>(typeattr.getValue());
1682
1683 if (type != operandType)
1684 return;
1685 }
1686 }
1687 }
1688
1689 // All constraints are satisfied.
1690 res.push_back(op);
1691 return;
1692 };
1693
1694 (*payloadOps.begin())->walk(matchFun);
1695 if (incorrectNumOperandTypes)
1696 return emitDefiniteFailure("If filter_operand_types contains more than a "
1697 "type, then it must contain as much types as "
1698 "the number of operands in the target ops");
1699 results.set(cast<OpResult>(getResult()), res);
1701}
1702
1703//===---------------------------------------------------------------------===//
1704// MultiTileSizesOp
1705//===---------------------------------------------------------------------===//
1706
1708 Type targetType, Type lowSizeType, Type,
1709 Type) {
1710 printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1711}
1712
1713static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1714 Type &targetType, Type &lowSizeType,
1715 Type &highSizeType,
1716 Type &splitPointType) {
1717 FunctionType funcType;
1718 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1719 if (failed(parser.parseType<FunctionType>(funcType)))
1720 return failure();
1721
1722 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1723 parser.emitError(typeLoc) << "expects a trailing functional type with one "
1724 "argument and one result";
1725 }
1726 targetType = funcType.getInput(0);
1727 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1728
1729 return success();
1730}
1731
1732DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1733 transform::TransformRewriter &rewriter, LinalgOp target,
1735 if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1736 if (target.hasDynamicShape()) {
1737 auto diag = emitSilenceableError()
1738 << "cannot compute parametric tile sizes for dynamically "
1739 "shaped payload op";
1740 diag.attachNote(target->getLoc()) << "payload op";
1741 return diag;
1742 }
1743
1744 FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1745 target, getDimension(), getTargetSize(), getDivisor());
1746 if (failed(spec)) {
1747 return emitSilenceableError()
1748 << "failed to compute multi-size tiling sizes";
1749 }
1750
1751 Builder builder(target.getContext());
1752 results.assign(llvm::map_range(
1753 ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1754 spec->lowTileSize * spec->lowTripCount}),
1755 [&builder, this](int64_t value) {
1756 return builder.getIntegerAttr(
1757 cast<ParamType>(getLowSize().getType()).getType(), value);
1758 }));
1760 }
1761
1762 OpBuilder builder(target.getContext());
1763 builder.setInsertionPoint(target);
1764 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1765 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1766 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1767 builder, target, getDimension(), targetSize, divisor);
1768 if (failed(spec)) {
1769 return emitSilenceableError() << "could not generate tile size computation";
1770 }
1771
1772 AffineExpr s0 = builder.getAffineSymbolExpr(0);
1773 AffineExpr s1 = builder.getAffineSymbolExpr(1);
1774 Operation *splitPoint =
1775 affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1776 {spec->lowTileSize, spec->lowTripCount});
1777 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1778 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1779 assert(lowTileSize && highTileSize && splitPoint &&
1780 "tile sizes are not produced by operations");
1781 results.reserve(results.size() + 3);
1782 results.push_back(lowTileSize);
1783 results.push_back(highTileSize);
1784 results.push_back(splitPoint);
1786}
1787
1788void transform::MultiTileSizesOp::getEffects(
1790 onlyReadsHandle(getTargetMutable(), effects);
1791 producesHandle(getOperation()->getOpResults(), effects);
1792 if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1793 onlyReadsPayload(effects);
1794 else
1795 modifiesPayload(effects);
1796}
1797
1798LogicalResult transform::MultiTileSizesOp::verify() {
1799 if (getLowSize().getType() != getHighSize().getType() ||
1800 getLowSize().getType() != getSplitPoint().getType()) {
1801 return emitOpError() << "expects all results type to be the same";
1802 }
1803 return success();
1804}
1805
1806//===---------------------------------------------------------------------===//
1807// PackOp
1808//===---------------------------------------------------------------------===//
1809
1810void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1811 Value target,
1812 ArrayRef<OpFoldResult> mixedPackedSizes) {
1813 SmallVector<int64_t> staticPackedSizes;
1814 SmallVector<Value> dynamicPackedSizes;
1815 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1816 staticPackedSizes);
1817 // Call the default builder which sets up the proper operands segment sizes
1818 // attributes for multiple variadic operands. In the absence of this, horrible
1819 // bugs ensue.
1820 Type linalgOpHType = transform::OperationType::get(
1821 builder.getContext(), GenericOp::getOperationName());
1822 build(builder, result,
1823 /*resultType=*/linalgOpHType,
1824 /*target=*/target,
1825 /*dynamic_sizes=*/dynamicPackedSizes,
1826 /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1827}
1828
1829SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1830 Builder b(getContext());
1831 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1832}
1833
1835transform::PackOp::apply(transform::TransformRewriter &rewriter,
1836 transform::TransformResults &transformResults,
1838 auto targetOps = state.getPayloadOps(getTarget());
1839 // If nothing to pack, propagate success.
1840 if (std::empty(targetOps)) {
1841 transformResults.set(cast<OpResult>(getPackedOp()),
1844 }
1845 // Fail on multi-op handles.
1846 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1847 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1848 return emitSilenceableError()
1849 << "requires target to map to exactly 1 LinalgOp (got "
1850 << llvm::range_size(targetOps) << ")";
1851 }
1852 // Fail on mismatched number of pack sizes.
1853 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1854 return emitSilenceableError()
1855 << "requires number of packed sizes match the number of loops ("
1856 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1857 << ")";
1858 }
1859
1860 // Unpack handles to constants or actual SSA index values.
1861 SmallVector<OpFoldResult> packedSizes;
1863 state, *this, packedSizes, getMixedPackedSizes());
1864
1865 rewriter.setInsertionPoint(linalgOp);
1866 FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1867 if (failed(maybeResult))
1868 return emitDefiniteFailure("data tiling failed");
1869
1870 transformResults.set(cast<OpResult>(getPackedOp()),
1871 {maybeResult->packedLinalgOp.getOperation()});
1873}
1874
1875void transform::PackOp::getEffects(
1877 transform::consumesHandle(getTargetMutable(), effects);
1878 transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1879 transform::producesHandle(getOperation()->getOpResults(), effects);
1881}
1882
1883//===---------------------------------------------------------------------===//
1884// PackGreedilyOp.
1885//===---------------------------------------------------------------------===//
1886
1887LogicalResult transform::PackGreedilyOp::verify() {
1888 if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1889 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1890 << " is not a valid permutation";
1891 }
1892 // TODO: relax to allow empty once we have another strategy than just matmul.
1893 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1894 for (auto [s, nmo] :
1895 llvm::zip_equal(getMixedMatmulPackedSizes(),
1896 getMatmulPaddedSizesNextMultipleOf())) {
1897 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1898 if (nmo != 0 &&
1899 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1900 return emitOpError() << "at most one of the packed_size and the "
1901 "padded_sizes_next_multiple_of can be nonzero "
1902 "for the matmul strategy";
1903 }
1904 }
1905 }
1906 return success();
1907}
1908
1910PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1911 transform::TransformResults &transformResults,
1914 for (Operation *op : state.getPayloadOps(getTarget())) {
1915 auto linalgOp = dyn_cast<LinalgOp>(op);
1916 if (!linalgOp)
1917 continue;
1918 // linalgOp will be replaced and the insertion point may be invalidated if
1919 // we set it before -> set it after.
1920 rewriter.setInsertionPointAfter(linalgOp);
1921 // Failing to pack greedily is perfectly fine.
1922 // In the future we will want to order packings according to some metric.
1923 FailureOr<PackResult> packResult = packMatmulGreedily(
1924 /*rewriter=*/rewriter,
1925 /*linalgOp=*/linalgOp,
1926 /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1927 /*mnkPaddedSizesNextMultipleOf=*/
1928 getMatmulPaddedSizesNextMultipleOf(),
1929 /*mnkOrder=*/getMatmulInnerDimsOrder());
1930 if (succeeded(packResult)) {
1931 results.push_back(packResult->packedLinalgOp);
1932 continue;
1933 }
1934 results.push_back(linalgOp);
1935 }
1936 transformResults.set(cast<OpResult>(getPackedOp()), results);
1938}
1939
1940SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1941 Builder b(getContext());
1942 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1943 b);
1944}
1945
1946void transform::PackGreedilyOp::getEffects(
1948 transform::consumesHandle(getTargetMutable(), effects);
1949 transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1950 transform::producesHandle(getOperation()->getOpResults(), effects);
1952}
1953
1954//===---------------------------------------------------------------------===//
1955// PackTransposeOp
1956//===---------------------------------------------------------------------===//
1957
1958LogicalResult transform::PackTransposeOp::verify() {
1959 if (!isPermutationVector(getInnerPerm())) {
1960 return emitOpError() << getInnerPermAttrName()
1961 << " is not a valid permutation";
1962 }
1963 if (!isPermutationVector(getOuterPerm())) {
1964 return emitOpError() << getOuterPermAttrName()
1965 << " is not a valid permutation";
1966 }
1967 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1968 return emitOpError() << " at least one of " << getInnerPermAttrName()
1969 << " or " << getOuterPermAttrName()
1970 << " must be specified";
1971 }
1972 return success();
1973}
1974
namespace {
/// Discriminates whether a permutation applies to the outer dims
/// (`outer_dims_perm`) or the inner tile dims (`inner_dims_pos`) of a
/// pack/unpack op.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
1978
1979/// Return true if `permutation` is a valid permutation of the
1980/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1981/// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op.
1982/// This is the case when the `permutation` rank matches the rank expected by
1983/// `op` and `permutation` is itself a permutation vector.
1984/// Return true if either `op` or `permutation` are empty to allow a simpler
1985/// polymorphic implementation.
1986template <typename RelayoutOpTy>
1987static bool isValidPackingPermutation(
1988 RelayoutOpTy op, ArrayRef<int64_t> permutation,
1989 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1990 static_assert(
1991 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1992 "applies to only pack or unpack operations");
1993 if (!op || permutation.empty())
1994 return true;
1995 size_t innerRank = op.getInnerDimsPos().size();
1996 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1997 return permutation.size() == innerRank && isPermutationVector(permutation);
1998 // op.getOuterDimsPerm() may be empty, in which case it is identity.
1999 // Don't rely on it.
2000 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
2001 return permutation.size() == op.getSourceRank() &&
2002 isPermutationVector(permutation);
2003 }
2004 return permutation.size() == op.getDestRank() &&
2005 isPermutationVector(permutation);
2006}
2007
2009transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
2010 transform::TransformResults &transformResults,
2012 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
2013 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
2014 // Step 1. If nothing to pack, propagate success.
2015 if (std::empty(packOrUnpackOps)) {
2016 transformResults.set(cast<OpResult>(getPackedOp()), {});
2017 transformResults.set(cast<OpResult>(getPackOp()), {});
2018 transformResults.set(cast<OpResult>(getUnPackOp()), {});
2020 }
2021
2022 // Step 2. Bunch of runtime sanity check and error messages.
2023 // Step 2.1. Fail on multi-op handles.
2024 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2025 !llvm::hasSingleElement(linalgOps)) {
2026 return emitSilenceableError()
2027 << "requires target to map to exactly 1 "
2028 "packing op and 1 packed op ("
2029 << "got " << llvm::range_size(packOrUnpackOps) << " and "
2030 << llvm::range_size(linalgOps) << ")";
2031 }
2032
2033 // Step 2.2. Fail on wrong type.
2034 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2035 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2036 if ((!packOp && !unPackOp)) {
2037 return emitSilenceableError() << "requires target to map to a "
2038 "linalg.pack or linalg.unpack";
2039 }
2040 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2041 if (!linalgOpTarget)
2042 return emitSilenceableError() << "requires a LinalgOp target";
2043
2044 // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
2045 LinalgOp linalgOp;
2046 if (packOp && packOp.getResult().hasOneUse())
2047 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2048 else if (unPackOp)
2049 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2050 if (linalgOp != linalgOpTarget) {
2051 auto errorMsg =
2052 packOp ? StringLiteral{"not a single use by the LinalgOp target"}
2053 : StringLiteral{"not produced by the LinalgOp target"};
2054 return emitSilenceableError() << errorMsg;
2055 }
2056
2057 // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
2058 // PackOp.
2059 if (unPackOp) {
2060 assert(!packOp && "packOp must be null on entry when unPackOp is not null");
2061 OpOperand *packUse = linalgOp.getDpsInitOperand(
2062 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2063 packOp = packUse->get().getDefiningOp<linalg::PackOp>();
2064 if (!packOp || !packOp.getResult().hasOneUse())
2065 return emitSilenceableError() << "could not find matching pack op";
2066 }
2067
2068 // Step 2.5. Fail if any permutation does not validate.
2069 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2070 ArrayRef<int64_t> perm =
2071 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2072 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2073 ? StringLiteral{"invalid outer_perm"}
2074 : StringLiteral{"invalid inner_perm"};
2075 if (!isValidPackingPermutation(packOp, perm, permType) ||
2076 !isValidPackingPermutation(unPackOp, perm, permType)) {
2077 Operation *packOrUnpackOp =
2078 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2079 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
2080 }
2081 }
2082
2083 // From here on, packOp and linalgOp are always present, unPackOp may or may
2084 // not be present.
2085 assert(packOp && linalgOp && "unexpected null op");
2086
2087 // Step 3. Actually transpose the ops.
2088 FailureOr<PackTransposeResult> res = packTranspose(
2089 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2090 // Preconditions have been checked, it is an error to fail here.
2091 assert(succeeded(res) && "unexpected packTranspose failure");
2092
2093 // Step 4. Return results.
2094 transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2095 transformResults.set(cast<OpResult>(getPackedOp()),
2096 {res->transposedLinalgOp});
2097 if (unPackOp) {
2098 transformResults.set(cast<OpResult>(getUnPackOp()),
2099 {res->transposedUnPackOp});
2100 } else {
2101 transformResults.set(cast<OpResult>(getUnPackOp()), {});
2102 }
2103
2105}
2106
2107//===---------------------------------------------------------------------===//
2108// PadOp
2109//===---------------------------------------------------------------------===//
2110
2111void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2112 ArrayRef<int64_t> paddingDimensions,
2113 ArrayRef<int64_t> padToMultipleOf,
2114 ArrayRef<int64_t> nofoldFlags,
2115 ArrayRef<Attribute> transposePaddings,
2116 StringRef copyBackOp,
2117 bool usePrescribedTensorShapes) {
2118 auto resultType = transform::AnyOpType::get(b.getContext());
2119 return build(/*odsBuilder=*/b,
2120 /*result=*/result,
2121 /*types=*/TypeRange{resultType, resultType},
2122 /*target=*/target,
2123 /*padding_values=*/ArrayAttr(), // let inference handle this
2124 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2125 /*pad_to_multiple_of=*/ValueRange{},
2126 /*padToMultipleOf=*/
2127 (padToMultipleOf.empty()
2129 : b.getDenseI64ArrayAttr(padToMultipleOf)),
2130 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2131 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2132 /*copy_back_op=*/b.getStringAttr(copyBackOp),
2133 /*use_prescribed_tensor_shapes=*/
2134 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
2135}
2136
2137void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2138 ArrayRef<int64_t> paddingDimensions,
2139 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
2140 ArrayRef<int64_t> nofoldFlags,
2141 ArrayRef<Attribute> transposePaddings,
2142 StringRef copyBackOp,
2143 bool usePrescribedTensorShapes) {
2144 auto resultType = transform::AnyOpType::get(b.getContext());
2145 SmallVector<int64_t> staticPadToMultipleOf;
2146 SmallVector<Value> dynamicPadToMultipleOf;
2147 dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
2148 staticPadToMultipleOf);
2149 return build(/*odsBuilder=*/b,
2150 /*result=*/result,
2151 /*types=*/TypeRange{resultType, resultType},
2152 /*target=*/target,
2153 /*padding_values=*/ArrayAttr(), // let inference handle this
2154 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2155 /*pad_to_multiple_of=*/dynamicPadToMultipleOf,
2156 /*padToMultipleOf=*/staticPadToMultipleOf,
2157 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2158 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2159 /*copy_back_op=*/copyBackOp,
2160 /*use_prescribed_tensor_shapes=*/usePrescribedTensorShapes);
2161}
2162
2163void PadOp::getEffects(
2165 consumesHandle(getTargetMutable(), effects);
2166 onlyReadsHandle(getPadToMultipleOfMutable(), effects);
2167 producesHandle(getOperation()->getOpResults(), effects);
2168 modifiesPayload(effects);
2169}
2170
2171SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2172 Builder b(getContext());
2173 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
2174}
2175
2176DiagnosedSilenceableFailure
2177transform::PadOp::apply(transform::TransformRewriter &rewriter,
2178 transform::TransformResults &results,
2179 transform::TransformState &state) {
2180 auto transformOp = cast<TransformOpInterface>(getOperation());
2181 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
2182
2183 for (Operation *target : state.getPayloadOps(getTarget())) {
2184 auto linalgTarget = dyn_cast<LinalgOp>(target);
2185 if (!linalgTarget) {
2186 auto diag = emitSilenceableError() << "expected LinalgOp target";
2187 diag.attachNote(target->getLoc()) << "target op";
2188 return diag;
2189 }
2190
2191 // Convert the integer packing flags to booleans.
2192 SmallVector<bool> nofoldFlags;
2193 for (int64_t packPadding :
2194 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
2195 nofoldFlags.push_back(static_cast<bool>(packPadding));
2196
2197 // Convert the padding values to attributes.
2198 SmallVector<Attribute> paddingValues;
2199 for (auto const &[untypedAttr, elementOrTensorType] :
2200 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2201
2202 if (matchPattern(untypedAttr, ub::m_Poison())) {
2203 paddingValues.push_back(untypedAttr);
2204 continue;
2205 }
2206 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2207 if (!attr) {
2208 emitOpError("expects padding values to be typed attributes or poison");
2210 }
2211 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2212 // Try to parse string attributes to obtain an attribute of element type.
2213 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2214 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2215 stringAttr, getContext(), elementType,
2216 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2217 if (!parsedAttr || parsedAttr.getType() != elementType) {
2218 auto diag = this->emitOpError("expects a padding that parses to ")
2219 << elementType << ", got " << untypedAttr;
2220 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2222 }
2223 paddingValues.push_back(parsedAttr);
2224 continue;
2225 }
2226 // Otherwise, add the attribute directly.
2227 if (attr.getType() != elementType) {
2228 auto diag = this->emitOpError("expects a padding value of type ")
2229 << elementType << ", got " << attr;
2230 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2232 }
2233 paddingValues.push_back(attr);
2234 }
2235
2236 // Extract the transpose vectors.
2237 SmallVector<SmallVector<int64_t>> transposePaddings;
2238 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2239 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2240 cast<ArrayAttr>(transposeVector)));
2241
2242 LinalgOp paddedOp;
2243 LinalgPaddingOptions options;
2244 options.paddingDimensions =
2245 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2246
2247 SmallVector<int64_t> padToMultipleOf;
2248 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
2249 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2250 if (!status.succeeded())
2251 return status;
2252 if (padToMultipleOf.empty())
2253 padToMultipleOf =
2254 SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2255
2256 options.padToMultipleOf = padToMultipleOf;
2257 options.paddingValues = paddingValues;
2258 options.nofoldFlags = nofoldFlags;
2259 if (getCopyBackOp() ==
2260 bufferization::MaterializeInDestinationOp::getOperationName()) {
2261 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2262 BufferizationMaterializeInDestination;
2263 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2264 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2265 } else if (getCopyBackOp() == kCopyOpNone) {
2266 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2267 } else {
2268 llvm_unreachable("unsupported copy_back op");
2269 }
2270 // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2271 bool irChanged = false;
2272 if (getUsePrescribedTensorShapes() &&
2273 linalgTarget.hasPureTensorSemantics()) {
2274 OpBuilder::InsertionGuard g(rewriter);
2275 rewriter.setInsertionPoint(linalgTarget);
2276 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2277 for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2278 if (ShapedType::isStatic(dim))
2279 continue;
2280 options.setSizeToPadTo(operand.getOperandNumber(), i,
2281 tensor::getMixedSize(rewriter,
2282 operand.get().getLoc(),
2283 operand.get(), i));
2284 irChanged = true;
2285 }
2286 }
2287 }
2288
2289 SmallVector<Value> replacements;
2290 SmallVector<tensor::PadOp> newPadOps;
2291 if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2292 replacements, newPadOps))) {
2293 if (irChanged) {
2294 auto diag = emitDefiniteFailure() << "failed to pad op";
2295 diag.attachNote(target->getLoc()) << "target op";
2296 return diag;
2297 }
2298 auto diag = emitSilenceableError() << "failed to pad op";
2299 diag.attachNote(target->getLoc()) << "target op";
2300 return diag;
2301 }
2302
2303 // We need to perform our own replacement here because this API is still
2304 // used in patterns that "pad and hoist", for which the replacement values
2305 // need to be different.
2306 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2307 // that we have more composable abstractions.
2308 rewriter.replaceOp(linalgTarget, replacements);
2309 paddedOps.push_back(paddedOp);
2310 padOps.append(newPadOps.begin(), newPadOps.end());
2311 if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2312 for (Value v : replacements) {
2313 Operation *copyBackOp = v.getDefiningOp();
2314 if (!llvm::is_contained(copyBackOps, copyBackOp))
2315 copyBackOps.push_back(copyBackOp);
2316 }
2317 }
2318 }
2319
2320 results.set(cast<OpResult>(getPadded()), paddedOps);
2321 results.set(cast<OpResult>(getPad()), padOps);
2322 results.set(cast<OpResult>(getCopy()), copyBackOps);
2324}
2325
2326LogicalResult transform::PadOp::verify() {
2327 SmallVector<int64_t> nofoldFlags =
2328 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2329 if (any_of(nofoldFlags, [](int64_t packPadding) {
2330 return packPadding != 0 && packPadding != 1;
2331 })) {
2332 return emitOpError()
2333 << "expects nofold_flags to contain booleans (0/1), found "
2334 << getNofoldFlags();
2335 }
2336
2337 SmallVector<int64_t> paddingDimensions =
2338 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2339 if (any_of(paddingDimensions,
2340 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2341 return emitOpError() << "expects padding_dimensions to contain positive "
2342 "integers, found "
2343 << getPaddingDimensions();
2344 }
2345 if (!getMixedPadToMultipleOf().empty()) {
2346 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2347 return emitOpError() << "expects as many multiples as padding_dimensions";
2348 }
2349 }
2350 ArrayAttr transposes = getTransposePaddings();
2351 for (Attribute attr : transposes) {
2352 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2353 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2354 if (!std::is_permutation(sequence.begin(), sequence.end(),
2355 transpose.begin(), transpose.end())) {
2356 return emitOpError()
2357 << "expects transpose_paddings to be a permutation, found "
2358 << attr;
2359 }
2360 }
2361 if (getCopyBackOp() !=
2362 bufferization::MaterializeInDestinationOp::getOperationName() &&
2363 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2364 getCopyBackOp() != kCopyOpNone)
2365 return emitOpError() << "invalid copy_back_op";
2366 return success();
2367}
2368
2369//===---------------------------------------------------------------------===//
2370// PadTilingInterfaceOp
2371//===---------------------------------------------------------------------===//
2372
2373void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2374 OperationState &result,
2375 Value target,
2376 ArrayRef<int64_t> paddingSizes,
2377 bool padToMultipleOf) {
2378 auto resultType = transform::AnyOpType::get(b.getContext());
2379 return build(/*odsBuilder=*/b,
2380 /*result=*/result,
2381 /*types=*/TypeRange{resultType, resultType},
2382 /*target=*/target,
2383 /*padding_values=*/ArrayAttr(), // let inference handle this
2384 /*padding_sizes=*/ValueRange{},
2385 /*paddingSizes=*/
2386 (paddingSizes.empty() ? DenseI64ArrayAttr()
2387 : b.getDenseI64ArrayAttr(paddingSizes)),
2388 /*pad_to_multiple_of=*/
2389 padToMultipleOf ? b.getUnitAttr() : nullptr);
2390}
2391
2392void transform::PadTilingInterfaceOp::build(
2393 OpBuilder &b, OperationState &result, Value target,
2394 ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2395 auto resultType = transform::AnyOpType::get(b.getContext());
2396 SmallVector<int64_t> staticPaddingSizes;
2397 SmallVector<Value> dynamicPaddingSizes;
2398 dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2399 staticPaddingSizes);
2400 return build(/*odsBuilder=*/b,
2401 /*result=*/result,
2402 /*types=*/TypeRange{resultType, resultType},
2403 /*target=*/target,
2404 /*padding_values=*/ArrayAttr(), // let inference handle this
2405 /*padding_sizes=*/dynamicPaddingSizes,
2406 /*paddingSizes=*/staticPaddingSizes,
2407 /*usePrescribedTensorShapes=*/padToMultipleOf);
2408}
2409
void transform::PadTilingInterfaceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target op is replaced, so its handle is consumed; padding-size
  // handles are only read; all op results are fresh handles; the payload IR
  // is rewritten.
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getPaddingSizesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
2417
2418SmallVector<OpFoldResult>
2419transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2420 Builder b(getContext());
2421 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2422}
2423
2424DiagnosedSilenceableFailure
2425transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2426 transform::TransformResults &results,
2427 transform::TransformState &state) {
2428 SmallVector<Operation *> paddedOps, padOps;
2429
2430 for (Operation *target : state.getPayloadOps(getTarget())) {
2431 auto targetOp = dyn_cast<TilingInterface>(target);
2432 if (!targetOp) {
2433 auto diag = emitSilenceableError() << "expected TilingInterface target";
2434 diag.attachNote(target->getLoc()) << "target op";
2435 return diag;
2436 }
2437
2438 // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2439 // loopsToOperand map / C++ APIs to compute the effect of padding on
2440 // operands.
2441 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2442 auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2443 "supported atm";
2444 diag.attachNote(target->getLoc()) << "target op";
2445 return diag;
2446 }
2447
2448 // Convert the padding values to attributes.
2449 SmallVector<Attribute> paddingValues;
2450 for (auto const &[untypedAttr, elementOrTensorType] :
2451 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2452 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2453 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2454
2455 if (matchPattern(untypedAttr, ub::m_Poison())) {
2456 paddingValues.push_back(untypedAttr);
2457 continue;
2458 }
2459 if (!attr) {
2460 emitOpError("expects padding values to be typed attributes or poison");
2462 }
2463 // Try to parse string attributes to obtain an attribute of element type.
2464 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2465 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2466 stringAttr, getContext(), elementType,
2467 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2468 if (!parsedAttr || parsedAttr.getType() != elementType) {
2469 auto diag = this->emitOpError("expects a padding that parses to ")
2470 << elementType << ", got " << attr;
2471 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2473 }
2474 paddingValues.push_back(parsedAttr);
2475 continue;
2476 }
2477 // Otherwise, add the attribute directly.
2478 if (attr.getType() != elementType) {
2479 auto diag = this->emitOpError("expects a padding value of type ")
2480 << elementType << ", got " << attr;
2481 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2483 }
2484 paddingValues.push_back(attr);
2485 }
2486
2487 // Set options.
2488 PadTilingInterfaceOptions options;
2489 options.setPaddingValues(paddingValues)
2490 .setPaddingSizes(getMixedPaddingSizes())
2491 .setPadToMultipleOf(getPadToMultipleOf());
2492
2493 OpBuilder::InsertionGuard g(rewriter);
2494 rewriter.setInsertionPointAfter(targetOp);
2495 auto maybePadOps = rewriteAsPaddedOp(
2496 rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
2497 if (failed(maybePadOps)) {
2498 auto diag = emitSilenceableError() << "failed to pad op";
2499 diag.attachNote(target->getLoc()) << "target op";
2500 return diag;
2501 }
2502 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2503
2504 // Set transform results.
2505 paddedOps.push_back(paddedOp);
2506 padOps.append(paddedOperands.begin(), paddedOperands.end());
2507 rewriter.replaceOp(targetOp.getOperation(), slicedResults);
2508 }
2509
2510 results.set(cast<OpResult>(getPadded()), paddedOps);
2511 results.set(cast<OpResult>(getPad()), padOps);
2513}
2514
/// No additional invariants beyond those enforced by ODS.
LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2516
2517//===---------------------------------------------------------------------===//
2518// HoistPadOp
2519//===---------------------------------------------------------------------===//
2520
2521DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2522 transform::TransformRewriter &rewriter,
2523 transform::TransformResults &transformResults,
2524 transform::TransformState &state) {
2525 auto targetOps = state.getPayloadOps(getTarget());
2526 auto loopOps = state.getPayloadOps(getLoop());
2527 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2528 return emitDefiniteFailure()
2529 << "requires exactly one target and one loop handle (got "
2530 << llvm::range_size(targetOps) << " and "
2531 << llvm::range_size(loopOps) << ")";
2532 }
2533
2534 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2535 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2536 if (!padOp || !loopOp)
2537 return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2538
2539 FailureOr<linalg::detail::PackingResult> result =
2540 linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2541 getTranspose());
2542 if (failed(result))
2543 return emitDefiniteFailure() << "could not build packing loop nest";
2544
2545 if (result->clonedLoopIvs.empty()) {
2546 transformResults.set(cast<OpResult>(getPackingLoop()),
2547 {result->hoistedPadOp.getOperation()});
2549 }
2550 auto outerPackedLoop =
2551 scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2552 transformResults.set(cast<OpResult>(getPackingLoop()),
2553 {outerPackedLoop.getOperation()});
2555}
2556
2557LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2558 ArrayRef<int64_t> transpose = getTranspose();
2559 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2560 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2561 transpose.end())) {
2562 return emitOpError() << "expects transpose to be a permutation, found "
2563 << getTranspose();
2564 }
2565 return success();
2566}
2567
2568void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2569 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2570 transform::onlyReadsHandle(getTargetMutable(), effects);
2571 transform::onlyReadsHandle(getLoopMutable(), effects);
2572 transform::producesHandle(getOperation()->getOpResults(), effects);
2574}
2575
2576DiagnosedSilenceableFailure
2577transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2578 tensor::PadOp target,
2579 transform::ApplyToEachResultList &results,
2580 transform::TransformState &state) {
2581 tensor::PadOp hoistedPadOp;
2582 SmallVector<TransposeOp> transposeOps;
2583 FailureOr<Value> result =
2584 hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2585 hoistedPadOp, transposeOps);
2586 if (succeeded(result)) {
2587 // We need to perform our own replacement here because this API is still
2588 // used in patterns that "pad and hoist", for which the replacement values
2589 // need to be different.
2590 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2591 // that we have more composable abstractions.
2592 rewriter.replaceOp(target, *result);
2593 results.push_back(hoistedPadOp);
2595 }
2596 return emitDefaultSilenceableFailure(target);
2597}
2598
2599LogicalResult transform::HoistPadOp::verify() {
2600 ArrayRef<int64_t> transpose = getTranspose();
2601 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2602 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2603 transpose.end())) {
2604 return emitOpError() << "expects transpose to be a permutation, found "
2605 << getTranspose();
2606 }
2607 return success();
2608}
2609
2610//===----------------------------------------------------------------------===//
2611// PromoteOp
2612//===----------------------------------------------------------------------===//
2613
2614DiagnosedSilenceableFailure
2615transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2616 LinalgOp target,
2617 transform::ApplyToEachResultList &results,
2618 transform::TransformState &state) {
2619 LinalgPromotionOptions promotionOptions;
2620 if (!getOperandsToPromote().empty())
2621 promotionOptions = promotionOptions.setOperandsToPromote(
2622 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2623 if (getUseFullTilesByDefault())
2624 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2625 getUseFullTilesByDefault());
2626 if (getUseOriginalSubviewSize())
2627 promotionOptions =
2628 promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
2629 if (getUseAlloca())
2630 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2631 if (!getUseFullTileBuffers().empty())
2632 promotionOptions = promotionOptions.setUseFullTileBuffers(
2633 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2634 if (getAlignment().has_value())
2635 promotionOptions = promotionOptions.setAlignment(*getAlignment());
2636 if (getMemorySpace().has_value())
2637 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2638
2639 if (getMapping().has_value()) {
2640 // The mapping should only contain an element
2641 auto mapping = *getMapping();
2642 if (mapping.size() > 1)
2643 return emitDefaultDefiniteFailure(target);
2644
2645 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2646
2647 if (addressSpace.getAddressSpace() ==
2648 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2649 promotionOptions =
2650 promotionOptions
2654 .setUseFullTileBuffers({false, false});
2655 } else if (addressSpace.getAddressSpace() ==
2656 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2657 promotionOptions =
2658 promotionOptions
2662 .setUseFullTileBuffers({false, false});
2663 } else {
2664 return emitDefaultDefiniteFailure(target);
2665 }
2666 }
2667
2668 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2669 return emitDefaultDefiniteFailure(target);
2670
2671 rewriter.setInsertionPoint(target);
2672 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2673 if (failed(res))
2674 return emitDefaultDefiniteFailure(target);
2675 results.push_back(target);
2677}
2678
2679//===----------------------------------------------------------------------===//
2680// ReplaceOp
2681//===----------------------------------------------------------------------===//
2682
2683DiagnosedSilenceableFailure
2684transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2685 TransformResults &transformResults,
2686 TransformState &state) {
2687 auto payload = state.getPayloadOps(getTarget());
2688
2689 // Check for invalid targets.
2690 for (Operation *target : payload) {
2691 if (target->getNumOperands() > 0)
2692 return emitDefiniteFailure() << "expected target without operands";
2693 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2694 target->getNumRegions() > 0)
2695 return emitDefiniteFailure()
2696 << "expected target that is isolated from above";
2697 }
2698
2699 // Clone and replace.
2700 Operation *pattern = &getBodyRegion().front().front();
2701 SmallVector<Operation *> replacements;
2702 for (Operation *target : payload) {
2703 if (getOperation()->isAncestor(target))
2704 continue;
2705 rewriter.setInsertionPoint(target);
2706 Operation *replacement = rewriter.clone(*pattern);
2707 rewriter.replaceOp(target, replacement->getResults());
2708 replacements.push_back(replacement);
2709 }
2710 transformResults.set(cast<OpResult>(getReplacement()), replacements);
2712}
2713
2714void transform::ReplaceOp::getEffects(
2715 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2716 consumesHandle(getTargetMutable(), effects);
2717 producesHandle(getOperation()->getOpResults(), effects);
2718 modifiesPayload(effects);
2719}
2720
2721LogicalResult transform::ReplaceOp::verify() {
2722 if (!getBodyRegion().hasOneBlock())
2723 return emitOpError() << "expected one block";
2724 if (std::distance(getBodyRegion().front().begin(),
2725 getBodyRegion().front().end()) != 1)
2726 return emitOpError() << "expected one operation in block";
2727 Operation *replacement = &getBodyRegion().front().front();
2728 if (replacement->getNumOperands() > 0)
2729 return replacement->emitOpError()
2730 << "expected replacement without operands";
2731 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2732 replacement->getNumRegions() > 0)
2733 return replacement->emitOpError()
2734 << "expect op that is isolated from above";
2735 return success();
2736}
2737
2738//===----------------------------------------------------------------------===//
2739// ScalarizeOp
2740//===----------------------------------------------------------------------===//
2741
2742DiagnosedSilenceableFailure
2743transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2744 LinalgOp target,
2745 transform::ApplyToEachResultList &results,
2746 transform::TransformState &state) {
2747 scf::SCFTilingOptions tilingOptions;
2748 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2749 SmallVector<OpFoldResult> tileSizes;
2750 Location loc = target.getLoc();
2751 SmallVector<OpFoldResult> allShapeSizes =
2752 target.createFlatListOfOperandDims(b, loc);
2753 AffineMap map = target.getShapesToLoopsMap();
2754 if (!map)
2755 return tileSizes;
2756 SmallVector<OpFoldResult> shapeSizes =
2758 allShapeSizes);
2759 // If the shape size is dynamic, tile by 1.
2760 // Otherwise, do not tile (i.e. tile size 0).
2761 for (OpFoldResult shapeSize : shapeSizes) {
2762 tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2763 : b.getIndexAttr(1));
2764 }
2765 return tileSizes;
2766 });
2767 rewriter.setInsertionPoint(target);
2768 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2769 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2770 if (failed(maybeTilingResult))
2771 return emitDefaultDefiniteFailure(target);
2772
2773 if (target->getNumResults())
2774 rewriter.replaceOp(target, maybeTilingResult->replacements);
2775 else
2776 rewriter.eraseOp(target);
2777
2778 results.reserve(maybeTilingResult->tiledOps.size());
2779 for (Operation *tiled : maybeTilingResult->tiledOps)
2780 results.push_back(tiled);
2782}
2783
2784//===----------------------------------------------------------------------===//
2785// ConvertToLoopsOp
2786//===----------------------------------------------------------------------===//
2787
2788DiagnosedSilenceableFailure
2789transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2790 transform::TransformResults &results,
2791 transform::TransformState &state) {
2792 SmallVector<Operation *> loops;
2793 for (Operation *target : state.getPayloadOps(getTarget())) {
2794 auto tilingOp = dyn_cast<TilingInterface>(*target);
2795 if (!tilingOp) {
2796 DiagnosedSilenceableFailure diag =
2797 emitSilenceableError()
2798 << "expected the payload to implement TilingInterface";
2799 diag.attachNote(target->getLoc()) << "payload op";
2800 return diag;
2801 }
2802 rewriter.setInsertionPoint(target);
2803 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2804 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2805 if (failed(generatedLoops))
2806 return emitDefaultDefiniteFailure(target);
2807 for (scf::ForOp &loop : *generatedLoops) {
2808 loops.push_back(loop.getOperation());
2809 }
2810 rewriter.eraseOp(target);
2811 }
2812 results.set(cast<OpResult>(getResult()), loops);
2814}
2815
2816//===----------------------------------------------------------------------===//
2817// RewriteInDestinationPassingStyleOp
2818//===----------------------------------------------------------------------===//
2819
2820DiagnosedSilenceableFailure
2821transform::RewriteInDestinationPassingStyleOp::applyToOne(
2822 transform::TransformRewriter &rewriter, Operation *target,
2823 transform::ApplyToEachResultList &results,
2824 transform::TransformState &state) {
2825 rewriter.setInsertionPoint(target);
2826 FailureOr<Operation *> maybeResult =
2828 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2829 [&rewriter](auto op) {
2830 return rewriteInDestinationPassingStyle(rewriter, op);
2831 });
2832 if (failed(maybeResult))
2833 return emitDefaultSilenceableFailure(target);
2834 results.push_back(*maybeResult);
2836}
2837
2838//===----------------------------------------------------------------------===//
2839// SplitOp
2840//===----------------------------------------------------------------------===//
2841
2842DiagnosedSilenceableFailure
2843SplitOp::apply(transform::TransformRewriter &rewriter,
2844 TransformResults &results, TransformState &state) {
2845 // Collect the dynamic split points if provided.
2846 SmallVector<Operation *> payload =
2847 llvm::to_vector(state.getPayloadOps(getTarget()));
2848
2849 bool isMultiwaySplit = getMultiway();
2850
2851 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2852 return mlir::emitSilenceableFailure(getLoc())
2853 << "requires exactly one target when "
2854 "multiway split is enabled (got "
2855 << llvm::range_size(payload) << ")";
2856 }
2857
2858 SmallVector<OpFoldResult> chunkSizes;
2859
2860 if (!isMultiwaySplit)
2861 chunkSizes.reserve(payload.size());
2862
2863 if (getDynamicChunkSizes()) {
2865 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2866 chunkSizes = llvm::map_to_vector(
2867 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2868 if (op->getNumResults() != 1 ||
2869 !op->getResult(0).getType().isIndex()) {
2870 diag = emitSilenceableError()
2871 << "expected dynamic split point handle to point to a "
2872 "single-result index-typed op";
2873 diag.attachNote(op->getLoc()) << "dynamic split point";
2874 }
2875 return OpFoldResult(op->getResult(0));
2876 });
2877 } else {
2878 chunkSizes = llvm::map_to_vector(
2879 state.getParams(getDynamicChunkSizes()),
2880 [](Attribute attr) { return OpFoldResult(attr); });
2881 }
2882 if (diag.isSilenceableFailure())
2883 return diag;
2884
2885 // For multiway split, a single payload is expected to have multiple
2886 // split points.
2887 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2888 return emitDefiniteFailure()
2889 << "expected the dynamic split point handle to point to as "
2890 "many operations ("
2891 << chunkSizes.size() << ") as the target handle ("
2892 << payload.size() << ")";
2893 }
2894 } else {
2895 chunkSizes.resize(payload.size(),
2896 rewriter.getIndexAttr(getStaticChunkSizes()));
2897 }
2898
2899 auto checkStructuredOpAndDimensions =
2900 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2901 if (!linalgOp) {
2902 auto diag = emitSilenceableError() << "only applies to structured ops";
2903 diag.attachNote(loc) << "target op";
2904 return diag;
2905 }
2906
2907 if (getDimension() >= linalgOp.getNumLoops()) {
2908 auto diag = emitSilenceableError() << "dimension " << getDimension()
2909 << " does not exist in target op";
2910 diag.attachNote(loc) << "target op";
2911 return diag;
2912 }
2914 };
2915
2916 auto checkFailureInSplitting =
2917 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2918 if (hasFailed) {
2919 auto diag = emitDefiniteFailure() << "internal failure in splitting";
2920 diag.attachNote(loc) << "target op";
2921 return diag;
2922 }
2924 };
2925
2926 SmallVector<Operation *> opList;
2927 if (isMultiwaySplit) {
2928
2929 // Split a single target operation at multiple points.
2930 TilingInterface head, tail;
2931 Operation *target = payload.front();
2932
2933 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2934
2935 // Check that the target is a valid LinalgOp with correct dimensions.
2936 DiagnosedSilenceableFailure diag =
2937 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2938 if (diag.isSilenceableFailure())
2939 return diag;
2940
2941 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2942
2943 if (idx > 0)
2944 target = tail.getOperation();
2945
2946 if (!target)
2947 break;
2948
2949 linalgOp = cast<LinalgOp>(target);
2950 Location loc = target->getLoc();
2951
2952 rewriter.setInsertionPoint(linalgOp);
2953 std::tie(head, tail) = linalg::splitOp(
2954 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2955 getDimension(), chunkSize);
2956
2957 // Propagate errors.
2958 DiagnosedSilenceableFailure diag =
2959 checkFailureInSplitting(!head && !tail, loc);
2960 if (diag.isDefiniteFailure())
2961 return diag;
2962
2963 opList.push_back(head.getOperation());
2964 }
2965
2966 // Append any leftover parts to the end of the result list.
2967 if (tail)
2968 opList.push_back(tail.getOperation());
2969
2970 } else {
2971 // Split each target operation.
2972 SmallVector<Operation *> first, second;
2973 Operation *noSecondPart = nullptr;
2974 for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2975 Operation *target = std::get<0>(pair);
2976 Location loc = target->getLoc();
2977 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2978 DiagnosedSilenceableFailure diag =
2979 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2980
2981 if (diag.isSilenceableFailure())
2982 return diag;
2983
2984 rewriter.setInsertionPoint(linalgOp);
2985 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2986 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2987 getDimension(), std::get<1>(pair));
2988
2989 // Propagate errors.
2990 DiagnosedSilenceableFailure diagSplit =
2991 checkFailureInSplitting(!first.back() && !second.back(), loc);
2992 if (diagSplit.isDefiniteFailure())
2993 return diag;
2994
2995 // Do not add null second parts.
2996 if (!second.back()) {
2997 noSecondPart = target;
2998 second.pop_back();
2999 }
3000 }
3001
3002 if (second.size() != first.size() && !second.empty()) {
3003 auto diag = emitSilenceableError()
3004 << "splitting does not produce the second part for a subset "
3005 "of targets";
3006 diag.attachNote()
3007 << "expected splitting to produce the second part of all "
3008 "or none of the targets";
3009 diag.attachNote(noSecondPart->getLoc())
3010 << "first target with no second part";
3011 return diag;
3012 }
3013
3014 opList.append(first);
3015 if (!second.empty())
3016 opList.append(second);
3017 }
3018 results.set(cast<OpResult>(getSplitList()), opList);
3020}
3021
3022void SplitOp::getEffects(
3023 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3024 consumesHandle(getTargetMutable(), effects);
3025 if (getDynamicChunkSizes())
3026 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
3027 producesHandle(getOperation()->getOpResults(), effects);
3028 modifiesPayload(effects);
3029}
3030
3031ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
3032 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
3033 IntegerAttr staticChunkSizes;
3034 if (parser.parseOperand(target) || parser.parseKeyword("after"))
3035 return failure();
3036
3037 OptionalParseResult dynamicPointParseResult =
3038 parser.parseOptionalOperand(dynamicChunkSizes);
3039 if (!dynamicPointParseResult.has_value()) {
3040 int64_t staticChunkSizesValue;
3041 if (failed(parser.parseInteger(staticChunkSizesValue)))
3042 return failure();
3043
3044 staticChunkSizes =
3045 parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
3046 }
3047
3048 Type targetType;
3049 if (parser.parseOptionalAttrDict(result.attributes) ||
3050 parser.parseColonType(targetType) ||
3051 parser.resolveOperand(target, targetType, result.operands)) {
3052 return failure();
3053 }
3054 if (dynamicPointParseResult.has_value()) {
3055 Type chunkSizesType;
3056 if (failed(*dynamicPointParseResult) || parser.parseComma() ||
3057 parser.parseType(chunkSizesType) ||
3058 parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
3059 result.operands)) {
3060 return failure();
3061 }
3062
3063 staticChunkSizes =
3064 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
3065 }
3066
3067 result.addAttribute(
3068 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
3069 staticChunkSizes);
3070 result.addTypes(targetType);
3071 return success();
3072}
3073
3074void SplitOp::print(OpAsmPrinter &printer) {
3075 printer << " " << getTarget() << " after ";
3076 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
3077 if (staticChunkSize != ShapedType::kDynamic)
3078 printer << staticChunkSize;
3079 else
3080 printer << getDynamicChunkSizes();
3081 printer << " ";
3082 printer.printOptionalAttrDict(getOperation()->getAttrs(),
3083 {getStaticChunkSizesAttrName()});
3084 printer << " : " << getTarget().getType();
3085 if (staticChunkSize == ShapedType::kDynamic)
3086 printer << ", " << getDynamicChunkSizes().getType();
3087}
3088
3089LogicalResult SplitOp::verify() {
3090 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3091 (getDynamicChunkSizes() == nullptr)) {
3092 return emitOpError() << "expects either a dynamic or a static split "
3093 "point to be provided";
3094 }
3095 return success();
3096}
3097
3098//===----------------------------------------------------------------------===//
3099// SplitReductionOp
3100//===----------------------------------------------------------------------===//
3101
3102void transform::SplitReductionOp::build(
3103 OpBuilder &builder, OperationState &result, Value target,
3104 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
3105 bool useScalingAlgorithm, bool useAlloc) {
3106 MLIRContext *ctx = builder.getContext();
3107 result.addOperands(target);
3108 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
3109 builder.getI64IntegerAttr(splitFactor));
3110 result.addAttribute(
3111 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
3112 builder.getI64IntegerAttr(insertSplitDimension));
3113 if (innerParallel) {
3114 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
3115 builder.getUnitAttr());
3116 }
3117 if (useScalingAlgorithm) {
3118 result.addAttribute(
3119 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
3120 builder.getUnitAttr());
3121 }
3122 if (useAlloc) {
3123 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
3124 builder.getUnitAttr());
3125 }
3126 auto resultType = transform::AnyOpType::get(ctx);
3127 result.addTypes({resultType, resultType, resultType, resultType});
3128}
3129
3130DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3131 transform::TransformRewriter &rewriter, LinalgOp target,
3132 transform::ApplyToEachResultList &results,
3133 transform::TransformState &state) {
3134 ControlSplitReductionFn splitFn = [&](LinalgOp) {
3135 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3136 unsigned(getInsertSplitDimension()),
3137 bool(getInnerParallel())};
3138 };
3139 rewriter.setInsertionPoint(target);
3140 FailureOr<SplitReductionResult> splitResult =
3141 (getUseScalingAlgorithm())
3142 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
3143 : splitReduction(rewriter, target, splitFn, getUseAlloc());
3144 if (failed(splitResult))
3145 return emitDefaultDefiniteFailure(target);
3146
3147 results.push_back(splitResult->initOrAlloc);
3148 results.push_back(splitResult->fillOp);
3149 results.push_back(splitResult->splitLinalgOp);
3150 results.push_back(splitResult->resultCombiningLinalgOp);
3152}
3153
3154//===----------------------------------------------------------------------===//
3155// TileReductionUsingForOp
3156//===----------------------------------------------------------------------===//
3157
3158void transform::TileReductionUsingForOp::build(
3159 OpBuilder &builder, OperationState &result, Value target,
3160 ArrayRef<int64_t> staticTileSizes) {
3161 // Call the default builder.
3162 // This is future-proof re mixed static-dynamic and setting up the proper
3163 // operands segment sizes attributes for multiple variadic operands.
3164 // In the absence of this, horrible bugs ensue.
3165 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3166 MLIRContext *ctx = builder.getContext();
3167 auto opTy = transform::AnyOpType::get(ctx);
3168 auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
3169 build(builder, result,
3170 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3171 /*target=*/target,
3172 /*reduction_dims=*/nullptr,
3173 /*tile_sizes=*/staticTileSizesAttr);
3174}
3175
3176DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3177 transform::TransformRewriter &rewriter, Operation *target,
3178 transform::ApplyToEachResultList &results,
3179 transform::TransformState &state) {
3180 rewriter.setInsertionPoint(target);
3181
3182 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3183 if (!partialReductionOp) {
3185 target->getLoc(),
3186 "Operation should implement PartialReductionOpInterface");
3187 }
3188
3189 SmallVector<unsigned> reductionDims =
3190 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3191 if (reductionDims.empty()) {
3192 for (auto [idx, iteratorType] :
3193 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3194 if (iteratorType == utils::IteratorType::reduction)
3195 reductionDims.push_back(idx);
3196 }
3197 }
3198
3199 scf::SCFTilingOptions options;
3200 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3201 options.setReductionTilingStrategy(
3203 options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
3204 options.setReductionDims(reductionDims);
3205 FailureOr<scf::SCFTilingResult> result =
3206 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3207
3208 if (failed(result)) {
3209 return emitSilenceableFailure(getLoc(),
3210 "failed to tile using partial reduction");
3211 }
3212 rewriter.replaceOp(target, result->replacements);
3213 for (Value initValue : result->initialValues)
3214 results.push_back(initValue.getDefiningOp());
3215 for (auto *parallelTiledOp : result->tiledOps)
3216 results.push_back(parallelTiledOp);
3217 for (auto *mergeOp : result->mergeOps)
3218 results.push_back(mergeOp);
3219 results.push_back(result->loops.front());
3221}
3222
3223//===----------------------------------------------------------------------===//
3224// TileReductionUsingForallOp
3225//===----------------------------------------------------------------------===//
3226
3227void transform::TileReductionUsingForallOp::build(
3228 OpBuilder &builder, OperationState &result, Value target,
3229 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3230 ArrayAttr mapping) {
3231 // Call the default builder.
3232 // This is future-proof re mixed static-dynamic and setting up the proper
3233 // operands segment sizes attributes for multiple variadic operands.
3234 // In the absence of this, horrible bugs ensue.
3235 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3236 MLIRContext *ctx = builder.getContext();
3237 auto opTy = transform::AnyOpType::get(ctx);
3238 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3239 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3240 build(builder, result,
3241 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3242 /*target=*/target,
3243 /*reduction_dims=*/{},
3244 /*num_threads=*/staticNumThreadsAttr,
3245 /*tile_sizes=*/staticTileSizesAttr,
3246 /*mapping=*/mapping);
3247}
3248
3249DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3250 transform::TransformRewriter &rewriter, Operation *target,
3251 transform::ApplyToEachResultList &results,
3252 transform::TransformState &state) {
3253 rewriter.setInsertionPoint(target);
3254
3255 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3256 if (!partialReductionOp) {
3258 target->getLoc(),
3259 "Operation should implement PartialReductionOpInterface");
3260 }
3261 SmallVector<OpFoldResult> numThreads =
3262 getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
3263 SmallVector<OpFoldResult> tileSizes =
3265
3266 scf::SCFTilingOptions options;
3267 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3268 options.setReductionTilingStrategy(
3270 if (!getNumThreads().empty()) {
3271 options.setNumThreads(numThreads);
3272 } else {
3273 options.setTileSizes(tileSizes);
3274 }
3275 if (auto mapping = getMapping()) {
3276 options.setMapping(mapping.value().getValue());
3277 }
3278 SmallVector<unsigned> reductionDims =
3279 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3280 if (reductionDims.empty()) {
3281 for (auto [idx, iteratorType] :
3282 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3283 if (iteratorType == utils::IteratorType::reduction)
3284 reductionDims.push_back(idx);
3285 }
3286 }
3287 options.setReductionDims(reductionDims);
3288 FailureOr<scf::SCFTilingResult> result =
3289 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3290
3291 if (failed(result)) {
3292 auto diag = emitSilenceableError() << "could not tile reduction";
3293 return diag;
3294 }
3295 rewriter.replaceOp(target, result->replacements);
3296
3297 for (Value initValue : result->initialValues)
3298 results.push_back(initValue.getDefiningOp());
3299 for (auto *parallelTiledOp : result->tiledOps)
3300 results.push_back(parallelTiledOp);
3301 for (auto *mergeOp : result->mergeOps)
3302 results.push_back(mergeOp);
3303 results.push_back(result->loops.front());
3305}
3306
3307//===----------------------------------------------------------------------===//
3308// ContinuousTileSizesOp
3309//===----------------------------------------------------------------------===//
3310
3311DiagnosedSilenceableFailure
3312transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3313 TransformResults &transformResults,
3314 TransformState &state) {
3315
3316 SmallVector<Operation *> targetOps =
3317 llvm::to_vector(state.getPayloadOps(getTarget()));
3318
3319 if (!llvm::hasSingleElement(targetOps)) {
3320 return mlir::emitSilenceableFailure(getLoc())
3321 << "requires exactly one target (got " << llvm::range_size(targetOps)
3322 << ")";
3323 }
3324
3325 Operation *target = *targetOps.begin();
3326 auto linalgOp = dyn_cast<LinalgOp>(target);
3327 auto tileableOp = dyn_cast<TilingInterface>(target);
3328
3329 if (!linalgOp)
3330 return emitDefiniteFailure() << "expected Linalg Op";
3331
3332 OpBuilder builder(linalgOp.getContext());
3333
3334 if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
3335 if (linalgOp.hasDynamicShape()) {
3336 auto diag = emitSilenceableError()
3337 << "cannot compute parametric tile sizes for dynamically "
3338 "shaped payload op";
3339 diag.attachNote(linalgOp->getLoc()) << "payload op";
3340 return diag;
3341 }
3342
3343 FailureOr<StaticContinuousTileSizeSpecification> spec =
3344 computeStaticContinuousTileSizes(linalgOp, getDimension(),
3345 getTargetSize());
3346 if (failed(spec)) {
3347 return emitSilenceableError()
3348 << "failed to compute multi-size tiling sizes";
3349 }
3350
3351 SmallVector<int64_t> chunkSizes;
3352
3353 for (auto &&[tileSize, tripCount] :
3354 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3355 chunkSizes.push_back(tileSize * tripCount);
3356
3357 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3358 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3359 return builder.getI64IntegerAttr(value);
3360 });
3361 };
3362 transformResults.setParams(cast<OpResult>(getTileSizes()),
3363 getI64AttrsFromI64(spec->tileSizes));
3364 transformResults.setParams(cast<OpResult>(getChunkSizes()),
3365 getI64AttrsFromI64(chunkSizes));
3366
3368 }
3369
3370 builder.setInsertionPoint(linalgOp);
3371
3372 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
3373 unsigned dimension = getDimension();
3374
3375 FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
3376 builder, tileableOp, dimension, targetSize, true);
3377 if (failed(spec)) {
3378 return emitSilenceableError() << "could not generate tile size computation";
3379 }
3380
3381 AffineExpr s0 = builder.getAffineSymbolExpr(0);
3382 AffineExpr s1 = builder.getAffineSymbolExpr(1);
3383 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3384 return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
3385 ofrs);
3386 };
3387
3388 SmallVector<Value> chunkSizes;
3389 Value splitPoint;
3390 for (auto &&[tileSize, tripCount] :
3391 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3392 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3393 chunkSizes.push_back(splitPoint);
3394 }
3395
3396 auto getDefiningOps = [&](ArrayRef<Value> values) {
3397 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3398 return value.getDefiningOp();
3399 });
3400 };
3401
3402 transformResults.set(cast<OpResult>(getTileSizes()),
3403 getDefiningOps(spec->tileSizes));
3404 transformResults.set(cast<OpResult>(getChunkSizes()),
3405 getDefiningOps(chunkSizes));
3406
3408}
3409
3410LogicalResult transform::ContinuousTileSizesOp::verify() {
3411
3412 if (getTileSizes().getType() != getChunkSizes().getType()) {
3413 return emitOpError() << "expects all results type to be the same";
3414 }
3415
3416 return success();
3417}
3418
3419void transform::ContinuousTileSizesOp::getEffects(
3420 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3421 if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3422 onlyReadsPayload(effects);
3423 else
3424 modifiesPayload(effects);
3425 onlyReadsHandle(getTargetMutable(), effects);
3426 producesHandle(getOperation()->getOpResults(), effects);
3427}
3428
3430 Type targetType, Type tileSizes,
3431 Type) {
3432 printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
3433}
3434
// NOTE(review): the line naming this custom type parser (and its parser
// parameter) is not visible in this extract -- confirm against the original.
                                             Type &targetType,
                                             Type &tileSizesType,
                                             Type &chunkSizesType) {
  // Expect a trailing functional type, e.g. `(!transform.any_op) ->
  // !transform.param<i64>`.
  FunctionType funcType;
  llvm::SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseType<FunctionType>(funcType)))
    return failure();

  // NOTE(review): on an arity mismatch an error is emitted but parsing falls
  // through and still dereferences input/result 0 below -- this looks like a
  // missing `return failure();`; verify against upstream intent.
  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
  }
  // Both results share the single printed result type.
  targetType = funcType.getInput(0);
  tileSizesType = chunkSizesType = funcType.getResult(0);

  return success();
}
3453
3454//===----------------------------------------------------------------------===//
3455// TileUsingForOp
3456//===----------------------------------------------------------------------===//
3457
3458void transform::TileUsingForOp::build(
3459 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3460 Value target, ArrayRef<int64_t> staticTileSizes,
3461 ArrayRef<int64_t> interchange,
3462 std::optional<ArrayRef<bool>> scalableSizes) {
3463 return build(builder, result, loopTypes,
3464 /*target=*/target,
3465 /*mixedTileSizes=*/
3466 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3467 interchange, scalableSizes);
3468}
3469
3470void transform::TileUsingForOp::build(
3471 OpBuilder &builder, OperationState &result, Value target,
3472 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3473 std::optional<ArrayRef<bool>> scalableSizes) {
3474 build(builder, result, target,
3475 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3476 interchange, scalableSizes);
3477}
3478
3479void transform::TileUsingForOp::build(
3480 OpBuilder &builder, OperationState &result, Value target,
3481 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3482 std::optional<ArrayRef<bool>> scalableSizes) {
3483 // Loop types are automaticaly splat by the callee, setting up one is
3484 // enough.
3485 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3486 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3487 scalableSizes);
3488}
3489
3490void transform::TileUsingForOp::build(
3491 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3492 Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3493 ArrayRef<int64_t> interchange,
3494 std::optional<ArrayRef<bool>> scalableSizes) {
3495 SmallVector<int64_t> staticTileSizes;
3496 SmallVector<Value> dynamicTileSizes;
3497 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3498 // Call the default builder which sets up the proper operands segment sizes
3499 // attributes for multiple variadic operands. In the absence of this,
3500 // horrible bugs ensue.
3501 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3502 unsigned numExpectedLoops =
3503 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3504 SmallVector<Type> resultTypes;
3505 resultTypes.reserve(numExpectedLoops);
3506 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3507 "expected one loop type or as many as loops");
3508 if (loopTypes.size() == 1)
3509 resultTypes.append(numExpectedLoops, loopTypes[0]);
3510 else
3511 llvm::append_range(resultTypes, loopTypes);
3512 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3513 if (scalableSizes.has_value())
3514 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3515 build(builder, result, /*tiled_linalg_op=*/target.getType(),
3516 /*loops=*/resultTypes,
3517 /*target=*/target,
3518 /*dynamic_sizes=*/dynamicTileSizes,
3519 /*static_sizes=*/staticTileSizesAttr,
3520 /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3521 /*scalable_sizes=*/expandedScalableSizes);
3522}
3523
3524LogicalResult transform::TileUsingForOp::verify() {
3525 if (getMixedSizes().size() != getScalableSizes().size())
3526 return emitOpError("expected same number of sizes (")
3527 << getMixedSizes().size() << ") and scalable sizes ("
3528 << getScalableSizes().size() << ")";
3529 ArrayRef<int64_t> staticSizes = getStaticSizes();
3530 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3531 if (getLoops().size() != numExpectedLoops)
3532 return emitOpError("expected number of loops to tile (")
3533 << numExpectedLoops << ") to match number of `loops` results ("
3534 << getLoops().size() << ")";
3535 return success();
3536}
3537
// NOTE(review): several lines of this function are missing from this extract
// (inner lines 3638, 3668 and 3681 of the rendered source) -- the truncation
// points are flagged inline below; verify against the original file.
/// Applies tiling with scf.for to every payload op associated with the target
/// handle, validating dynamic-size/param operands against the targets first.
DiagnosedSilenceableFailure
transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
                                 TransformResults &transformResults,
                                 TransformState &state) {
  ArrayRef<int64_t> tileSizes = getStaticSizes();

  SmallVector<Operation *> targets =
      llvm::to_vector(state.getPayloadOps(getTarget()));
  // For each dynamic-size operand, exactly one of these two vectors gets a
  // non-empty entry: payload ops for handles, integers for params.
  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
  SmallVector<SmallVector<int64_t>> paramSizes;
  dynamicSizeProducers.reserve(getDynamicSizes().size());
  paramSizes.reserve(getDynamicSizes().size());
  for (Value transformValue : getDynamicSizes()) {
    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
      dynamicSizeProducers.push_back({});
      ArrayRef<Attribute> params = state.getParams(transformValue);
      paramSizes.push_back(llvm::map_to_vector(params, [](Attribute attr) {
        return cast<IntegerAttr>(attr).getValue().getSExtValue();
      }));

      if (paramSizes.back().size() != targets.size()) {
        // NOTE(review): this diagnostic prints
        // `dynamicSizeProducers.back().size()` -- the empty vector pushed
        // just above, i.e. always 0 -- instead of `paramSizes.back().size()`;
        // likely a copy-paste slip from the handle branch below.
        DiagnosedSilenceableFailure diag =
            emitSilenceableError()
            << "expected as many parameter values ("
            << dynamicSizeProducers.back().size() << ") as target ops ("
            << targets.size() << ")";
        diag.attachNote(transformValue.getLoc()) << "for this parameter";
        return diag;
      }

      continue;
    }
    paramSizes.push_back({});
    dynamicSizeProducers.push_back(
        llvm::to_vector(state.getPayloadOps(transformValue)));

    // One size-producing payload op is required per target op.
    if (dynamicSizeProducers.back().size() != targets.size()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "expected as many dynamic size-producing operations ("
          << dynamicSizeProducers.back().size() << ") as target ops ("
          << targets.size() << ")";
      diag.attachNote(transformValue.getLoc()) << "for this handle";
      return diag;
    }

    // Every producer must yield exactly one `index` value.
    for (Operation *op : dynamicSizeProducers.back()) {
      if (op->getNumResults() == 1 &&
          isa<IndexType>(op->getResult(0).getType())) {
        continue;
      }

      DiagnosedSilenceableFailure diag =
          emitSilenceableError() << "expected sizes to be produced by ops "
                                    "with a single index-type result";
      diag.attachNote(op->getLoc()) << "size producer op";
      diag.attachNote(transformValue.getLoc()) << "for this handle";
      return diag;
    }
  }

  SmallVector<Operation *> tiled;
  // loops[d] collects, across all targets, the generated loop at depth d.
  SmallVector<SmallVector<Operation *, 4>, 4> loops;
  loops.resize(getLoops().size());
  auto scalableSizes = getScalableSizes();
  for (auto [i, op] : llvm::enumerate(targets)) {
    auto tilingInterface = dyn_cast<TilingInterface>(op);
    if (!tilingInterface) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "only ops implementing TilingInterface are supported";
      diag.attachNote(op->getLoc()) << "target op";
      return diag;
    }
    if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
      DiagnosedSilenceableFailure diag =
          emitSilenceableError()
          << "too many tiles provided, expected at most "
          << tilingInterface.getLoopIteratorTypes().size() << " found "
          << tileSizes.size();
      diag.attachNote(op->getLoc()) << "target op";
      return diag;
    }

    scf::SCFTilingOptions tilingOptions;
    if (tileSizes.empty()) {
      tilingOptions.setTileSizeComputationFunction(
          [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
            return {};
          });
    } else {
      // `index` is captured by value: it selects this target's entry in the
      // per-target dynamic size / param vectors.
      tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
                                                                  Operation *) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(tileSizes.size());
        unsigned dynamicIdx = 0;

        for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
          if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
            if (scalableSizes[ofrIdx]) {
              // NOTE(review): the line materializing `val` (presumably an
              // arith constant of the attribute value) is missing from this
              // extract.
                  b, getLoc(), cast<IntegerAttr>(attr).getInt());
              // Scalable size: multiply the base size by vector.vscale.
              Value vscale =
                  vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
              sizes.push_back(
                  arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
            } else {
              sizes.push_back(attr);
            }
            continue;
          }
          ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
          ArrayRef<int64_t> params = paramSizes[dynamicIdx];
          ++dynamicIdx;
          assert((dynamicSizes.empty() ^ params.empty()) &&
                 "expected either dynamic sizes or parameters");
          if (!params.empty()) {
            sizes.push_back(b.getIndexAttr(params[index]));
          } else {
            sizes.push_back(dynamicSizes[index]->getResult(0));
          }
        }
        return sizes;
      });
    }

    tilingOptions.setInterchange(getInterchange());
    FailureOr<scf::SCFTilingResult> maybeTilingResult =
        tileUsingSCF(rewriter, tilingInterface, tilingOptions);
    if (failed(maybeTilingResult))
      // NOTE(review): the statement returned on failure here is missing from
      // this extract (likely a definite-failure return).

    rewriter.replaceOp(op, maybeTilingResult->replacements);

    tiled.append(maybeTilingResult->tiledOps);
    for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
      loops[en2.index()].push_back(en2.value());
  }

  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
  for (const auto &en : llvm::enumerate(loops))
    transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());

  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
3683
3684SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3685 ValueRange dynamic = getDynamicSizes();
3686 ArrayRef<int64_t> tileSizes = getStaticSizes();
3687 SmallVector<OpFoldResult> results;
3688 results.reserve(tileSizes.size());
3689 unsigned dynamicPos = 0;
3690 Builder builder(getContext());
3691 for (int64_t size : tileSizes) {
3692 if (size == ShapedType::kDynamic) {
3693 results.push_back(dynamic[dynamicPos++]);
3694 } else {
3695 results.push_back(builder.getIndexAttr(size));
3696 }
3697 }
3698 return results;
3699}
3700
/// Declares transform effects: the target handle is consumed (its payload
/// ops are erased and replaced), dynamic-size handles are only read, all
/// results are produced, and the payload IR is mutated.
void transform::TileUsingForOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getDynamicSizesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
3708
3709//===----------------------------------------------------------------------===//
3710// TileUsingForallOp
3711//===----------------------------------------------------------------------===//
3712
3713void transform::TileUsingForallOp::build(OpBuilder &builder,
3714 OperationState &result, Value target,
3715 ArrayRef<int64_t> staticTileSizes,
3716 transform::TileSizesSpec,
3717 ArrayAttr mapping) {
3718 return build(builder, result,
3719 /*target=*/target,
3720 /*mixedTileSizes=*/
3721 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3722 /*_=*/TileSizesSpec(),
3723 /*mapping=*/mapping);
3724}
3725
3726void transform::TileUsingForallOp::build(OpBuilder &builder,
3727 OperationState &result, Value target,
3728 ArrayRef<OpFoldResult> mixedTileSizes,
3729 transform::TileSizesSpec,
3730 ArrayAttr mapping) {
3731 SmallVector<int64_t> staticTileSizes;
3732 SmallVector<Value> dynamicTileSizes;
3733 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3734 // Call the default builder which sets up the proper operands segment sizes
3735 // attributes for multiple variadic operands. In the absence of this,
3736 // horrible bugs ensue.
3737 MLIRContext *ctx = builder.getContext();
3738 auto operationType = transform::AnyOpType::get(ctx);
3739 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3740 build(builder, result,
3741 /*resultTypes=*/TypeRange{operationType, operationType},
3742 /*target=*/target,
3743 /*num_threads=*/ValueRange{},
3744 /*tile_sizes=*/dynamicTileSizes,
3745 /*packed_num_threads=*/Value(),
3746 /*packed_tile_sizes=*/Value(),
3747 /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3748 /*static_tile_sizes=*/staticTileSizesAttr,
3749 /*mapping=*/mapping);
3750}
3751
3752void transform::TileUsingForallOp::build(OpBuilder &builder,
3753 OperationState &result, Value target,
3754 ArrayRef<int64_t> staticNumThreads,
3755 transform::NumThreadsSpec,
3756 ArrayAttr mapping) {
3757 return build(builder, result, target,
3758 getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3759 NumThreadsSpec(), mapping);
3760}
3761
3762void transform::TileUsingForallOp::build(OpBuilder &builder,
3763 OperationState &result, Value target,
3764 ArrayRef<OpFoldResult> mixedNumThreads,
3765 transform::NumThreadsSpec,
3766 ArrayAttr mapping) {
3767 SmallVector<int64_t> staticNumThreads;
3768 SmallVector<Value> dynamicNumThreads;
3769 dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3770 staticNumThreads);
3771 // Call the default builder which sets up the proper operands segment sizes
3772 // attributes for multiple variadic operands. In the absence of this,
3773 // horrible bugs ensue.
3774 MLIRContext *ctx = builder.getContext();
3775 auto operationType = transform::AnyOpType::get(ctx);
3776 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3777 build(builder, result,
3778 /*resultTypes=*/TypeRange{operationType, operationType},
3779 /*target=*/target,
3780 /*num_threads=*/dynamicNumThreads,
3781 /*tile_sizes=*/ValueRange{},
3782 /*packed_num_threads=*/Value(),
3783 /*packed_tile_sizes=*/Value(),
3784 /*static_num_threads=*/staticNumThreadsAttr,
3785 /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3786 /*mapping=*/mapping);
3787}
3788
/// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
/// normalized upper bound, i.e. `ceilDiv(ub - lb, step)`, so the loop can be
/// rewritten to iterate from 0 with unit step.
static SmallVector<OpFoldResult>
// NOTE(review): the declarator naming this helper and its leading parameters
// (rewriter, loc, lbs, ubs) is missing from this extract.
                     ArrayRef<OpFoldResult> steps) {
  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  // (ub - lb) ceildiv step, with s0 = lb, s1 = ub, s2 = step.
  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
  SmallVector<OpFoldResult> normalizedUbs;
  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
    // NOTE(review): the line materializing `normalizedUb` via a composed
    // affine apply is missing from this extract.
        rewriter, loc, normalizedUbExpr, {lb, ub, step});
    normalizedUbs.push_back(normalizedUb);
  }
  return normalizedUbs;
}
3806
/// When a loop is normalized, the uses of the induction variable within the
/// loop need to be replaced with `original_lb + old_iv * original_step`.
// NOTE(review): the declarator naming this helper (and the `lbs` parameter
// line) is missing from this extract.
                                          Location loc, ValueRange ivs,
                                          ArrayRef<OpFoldResult> steps) {
  AffineExpr s0, s1;
  AffineExpr d0;
  bindSymbols(rewriter.getContext(), s0, s1);
  bindDims(rewriter.getContext(), d0);
  // lb + iv * step, with d0 = iv, s0 = lb, s1 = step.
  AffineExpr denormExpr = s0 + d0 * s1;
  SmallVector<Value> denormalizedIvs;

  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
    // NOTE(review): the line computing `denormValue` via a composed affine
    // apply is missing from this extract.
        rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
    denormalizedIvs.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
  }
  return denormalizedIvs;
}
3828
/// Given a `scf.forall` loop return a loop op with the loop bounds
/// normalized.
/// TODO: Replace this with a general utility to normalize `scf.forall`.
/// At the time of writing, this wasn't done since adding this to `scf`
/// dialect would disallow using of `affine.apply` operations due
/// to cyclic dependencies. To avoid churn in lit tests
/// with the change this was added with, defer that to a follow up.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
                                           scf::ForallOp loop) {
  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
  SmallVector<OpFoldResult> steps = loop.getMixedStep();

  // Already normalized (zero lower bounds, unit steps): nothing to do.
  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
    return loop;
  }

  Location loc = loop.getLoc();
  // New bounds: [0, ceilDiv(ub - lb, step)) with step 1.
  SmallVector<OpFoldResult> normalizedUbs =
      normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
                                          rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
                                            rewriter.getIndexAttr(1));

  // Create the normalized loop with an empty body; the original body is
  // merged in below.
  auto normalizedForallOp = scf::ForallOp::create(
      rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
      loop.getOutputs(), loop.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
  OpBuilder::InsertionGuard g(rewriter);
  Block *normalizedLoopBlock = normalizedForallOp.getBody();
  rewriter.setInsertionPointToStart(normalizedLoopBlock);

  // Rebuild the original induction-variable values (lb + iv * step) at the
  // top of the new body, then splice the old body in, remapping its block
  // arguments (ivs followed by iter args).
  SmallVector<Value> argValues =
      denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
                   normalizedForallOp.getRegionIterArgs().end());
  Block *origLoopBlock = loop.getBody();
  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);

  rewriter.replaceOp(loop, normalizedForallOp);
  return normalizedForallOp;
}
3874
// NOTE(review): the line with this helper's return type and name is missing
// from this extract; it tiles a single payload op to an scf.forall loop and
// stores the result into `tilingResult`.
    RewriterBase &rewriter, transform::TransformState &state,
    TransformOpInterface transformOp, Operation *target,
    ArrayRef<OpFoldResult> mixedNumThreads,
    ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
    scf::SCFTilingResult &tilingResult) {
  // Transform all targets one by one.
  auto tileableOp = dyn_cast<TilingInterface>(target);
  if (!tileableOp) {
    // NOTE(review): the declaration of `diag` is missing from this extract.
        transformOp.emitSilenceableError()
        << "only TilingInterface ops are supported";
    diag.attachNote(target->getLoc()) << "target op";
    return diag;
  }
  rewriter.setInsertionPoint(tileableOp);
  scf::SCFTilingOptions options;
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  // Exactly one of num-threads / tile-sizes is specified (checked by the
  // caller's verifier).
  if (!mixedNumThreads.empty()) {
    options.setNumThreads(mixedNumThreads);
  } else {
    options.setTileSizes(mixedTileSizes);
  }
  if (mapping) {
    options.setMapping(mapping.value().getValue());
  }
  FailureOr<scf::SCFTilingResult> maybeTilingResult =
      scf::tileUsingSCF(rewriter, tileableOp, options);

  if (failed(maybeTilingResult))
    return transformOp.emitDefaultSilenceableFailure(tileableOp);

  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);

  tilingResult = *maybeTilingResult;

  // In the tile-sizes case, normalize the generated forall loop so the
  // resulting handle points at a zero-based unit-step loop.
  if (mixedNumThreads.empty()) {
    auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(generatedForallOp);
    scf::ForallOp normalizedForallOp =
        normalizeForallLoopOp(rewriter, generatedForallOp);
    tilingResult.loops.front() = normalizedForallOp;
  }

  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
3922
// NOTE(review): several lines of this function are missing from this extract
// (parameter lines, the `tileOps` declaration, the handle-unpacking calls,
// the `diag` declaration, and the final return); flagged inline below.
/// Tiles every payload op of the target handle to an scf.forall loop and
/// sets the `forall_op` and `tiled_op` results.
DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
    transform::TransformResults &transformResults,
  auto transformOp = cast<TransformOpInterface>(getOperation());

  // Result payload ops.
  SmallVector<Operation *> tiledOps;

  // Unpack handles.
  SmallVector<OpFoldResult> mixedNumThreads;
  // NOTE(review): the call selecting between the packed and unpacked
  // num-threads unpacking is partially missing from this extract.
      getPackedNumThreads()
              state, transformOp, mixedNumThreads, getPackedNumThreads())
              state, transformOp, mixedNumThreads, getMixedNumThreads());
  if (!status.succeeded())
    return status;
  SmallVector<OpFoldResult> mixedTileSizes;
  status = getPackedTileSizes()
               state, transformOp, mixedTileSizes, getPackedTileSizes())
               state, transformOp, mixedTileSizes, getMixedTileSizes());
  if (!status.succeeded())
    return status;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    scf::SCFTilingResult tilingResult;
    // NOTE(review): the declaration of `diag` calling the tiling helper is
    // missing from this extract.
        rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
        getMapping(), tilingResult);
    if (!diag.succeeded())
      return diag;
    tileOps.push_back(tilingResult.loops.front());
    tiledOps.append(tilingResult.tiledOps);
  }

  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);

}
3968
/// Declares transform effects: the target handle is consumed; all size and
/// thread-count handles are only read; results are produced; the payload IR
/// is mutated.
void transform::TileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getTileSizesMutable(), effects);
  onlyReadsHandle(getNumThreadsMutable(), effects);
  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
  onlyReadsHandle(getPackedTileSizesMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
3979
3980SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3981 Builder b(getContext());
3982 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3983}
3984
3985SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3986 Builder b(getContext());
3987 return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3988}
3989
3990LogicalResult TileUsingForallOp::verify() {
3991 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3992 static_cast<int>(getPackedNumThreads() != Value());
3993 if (numThreadsSpec > 1)
3994 return emitOpError(
3995 "num_threads and packed_num_threads are mutually exclusive");
3996 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3997 static_cast<int>(getPackedTileSizes() != Value());
3998 if (tileSizesSpec > 1)
3999 return emitOpError(
4000 "tile_sizes and packed_tile_sizes are mutually exclusive");
4001 if (numThreadsSpec == 0 && tileSizesSpec == 0)
4002 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
4003 "must be specified");
4004 return success();
4005}
4006
4007//===----------------------------------------------------------------------===//
4008// VectorizeChildrenAndApplyPatternsOp
4009//===----------------------------------------------------------------------===//
4010
4011void transform::VectorizeChildrenAndApplyPatternsOp::build(
4012 OpBuilder &builder, OperationState &result, Value target,
4013 bool foldTypeExtensionsIntoContract, bool vectorizePadding,
4014 bool vectorizeExtract, bool flatten1DDepthwiseConv) {
4015 result.addOperands(target);
4016 if (foldTypeExtensionsIntoContract) {
4017 result.addAttribute(
4018 VectorizeChildrenAndApplyPatternsOp::
4019 getFoldTypeExtensionsIntoContractAttrName(result.name),
4020 builder.getUnitAttr());
4021 }
4022 if (vectorizePadding) {
4023 result.addAttribute(
4024 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4025 result.name),
4026 builder.getUnitAttr());
4027 }
4028 if (vectorizeExtract) {
4029 result.addAttribute(
4030 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4031 result.name),
4032 builder.getUnitAttr());
4033 }
4034 if (flatten1DDepthwiseConv) {
4035 result.addAttribute(
4036 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4037 result.name),
4038 builder.getUnitAttr());
4039 }
4040 result.addTypes(transform::AnyOpType::get(builder.getContext()));
4041}
4042
namespace {
/// This is a helper only to call vectorize via a pattern inside of
/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
  explicit VectorizationPattern(MLIRContext *context,
                                bool vectorizeExtract = false,
                                bool flattenConv = false)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
        vectorizeNDExtract(vectorizeExtract),
        flatten1DDepthwiseConv(flattenConv) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // NOTE(review): the guard condition preceding this match-failure return
    // (likely a vectorizability check on `op`) is missing from this extract.
      return rewriter.notifyMatchFailure(op,
                                         "Unsupported Op, cannot vectorize");
    // Vectorize without explicit vector sizes; sizes are inferred from the
    // (static) shapes of the op.
    FailureOr<VectorizationResult> vectorResults =
        vectorize(rewriter, op, /*inputVectorSizes=*/{},
                  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
                  flatten1DDepthwiseConv);
    if (failed(vectorResults))
      return failure();
    rewriter.replaceOp(op, vectorResults->replacements);
    return success();
  }

private:
  /// Controls whether to vectorize `tensor.extract` when the input tensor is
  /// rank >= 2.
  bool vectorizeNDExtract = false;
  /// Controls whether to "flatten" the channel dimension when vectorising 1D
  /// depthwise convolutions. This should lead to better vectorization for
  /// tensors with a low number of channel dimensions.
  bool flatten1DDepthwiseConv = false;
};
} // namespace
4078
// NOTE(review): several lines of this function are missing from this extract
// (the failure return after the isolation check, the bodies of the
// pattern-populating `if`s, and the final return); flagged inline below.
/// Greedily applies a vectorization pattern plus cleanup/lowering patterns to
/// everything nested under the (isolated-from-above) target.
DiagnosedSilenceableFailure
transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Greedy rewriting requires an isolated-from-above root.
  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
    auto diag = this->emitOpError("requires isolated-from-above targets");
    diag.attachNote(target->getLoc()) << "non-isolated target";
    // NOTE(review): the failure-returning statement here is missing from
    // this extract.
  }

  MLIRContext *ctx = getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
                                     getFlatten_1dDepthwiseConv());

  // NOTE(review): the populate calls guarded by the two `if`s below are
  // missing from this extract.
  if (!getDisableTransferPermutationMapLoweringPatterns())

  if (!getDisableMultiReductionToContractPatterns())


  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
               linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                       /*benefit=*/2);
  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);

  patterns.add<CopyVectorizationPattern>(ctx);

  // NOTE(review): the populate calls for the two guarded feature flags below
  // are missing from this extract.
  if (getFoldTypeExtensionsIntoContract())

  if (getVectorizePadding()) {
    // This creates an alternative path for lowering tensor.pad - by
    // decomposing it into e.g. linalg.fill.
  }

  // Apply all collected patterns greedily; the tracking listener keeps the
  // transform state's handles up to date as payload ops are replaced.
  TrackingListener listener(state, *this);
  if (failed(
          applyPatternsGreedily(target, std::move(patterns),
                                GreedyRewriteConfig().setListener(&listener))))
    return emitDefaultDefiniteFailure(target);

  results.push_back(target);
  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
4132
4133//===----------------------------------------------------------------------===//
4134// VectorizeOp
4135//===----------------------------------------------------------------------===//
4136
// NOTE(review): three lines of this function are missing from this extract
// (the early return for empty targets, the guard opening the per-target
// check, and the final return); flagged inline below.
/// Vectorizes each payload op with the (possibly param-backed) vector sizes
/// reified from the op's operands.
DiagnosedSilenceableFailure transform::VectorizeOp::apply(
    transform::TransformRewriter &rewriter,
    mlir::transform::TransformResults &transformResults,
    mlir::transform::TransformState &state) {
  auto targets = state.getPayloadOps(getTarget());
  if (std::empty(targets))
    // NOTE(review): the early-return statement here is missing from this
    // extract.
  auto transformOp = cast<TransformOpInterface>(getOperation());
  // Resolve handles/params in the vector-size operands to concrete integers.
  SmallVector<int64_t> vectorSizes;
  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
      state, transformOp, getMixedVectorSizes(), vectorSizes);
  if (!status.succeeded())
    return status;

  // TODO: Check that the correct number of vectorSizes was provided.
  for (Operation *target : targets) {
    // NOTE(review): the guard condition opening this block (likely a
    // vectorizability check on `target`) is missing from this extract.
      return mlir::emitSilenceableFailure(target->getLoc())
             << "Unsupported Op, cannot vectorize";
    }
    FailureOr<VectorizationResult> vectorResults =
        linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
                          getVectorizeNdExtract().value_or(false),
                          /*flatten1DDepthwiseConv=*/false,
                          getAssumeDynamicDimsMatchVecSizes().value_or(false),
                          getCreateNamedContraction().value_or(false));
    if (failed(vectorResults)) {
      return mlir::emitSilenceableFailure(target->getLoc())
             << "Attempted to vectorize, but failed";
    }
    rewriter.replaceOp(target, vectorResults->replacements);
  }

  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
4172
/// Declares transform effects: the target handle is consumed, vector-size
/// handles are only read, and the payload IR is mutated.
void transform::VectorizeOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  onlyReadsHandle(getVectorSizesMutable(), effects);
  modifiesPayload(effects);
}
4179
4180SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4181 OpBuilder b(getContext());
4182 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
4183}
4184
4185LogicalResult transform::VectorizeOp::verify() {
4186 if (getStaticVectorSizes().size() != getScalableSizes().size())
4187 return emitOpError("expected same number of vector sizes (")
4188 << getStaticVectorSizes().size() << ") and scalable sizes ("
4189 << getScalableSizes().size() << ")";
4190 return success();
4191}
4192
4193//===----------------------------------------------------------------------===//
4194// HoistRedundantVectorTransfersOp
4195//===----------------------------------------------------------------------===//
4196
/// Hoists redundant vector transfer reads/writes out of loops in the target
/// function.
DiagnosedSilenceableFailure
transform::HoistRedundantVectorTransfersOp::applyToOne(
    transform::TransformRewriter &rewriter, func::FuncOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // WARNING: This hoisting does not model parallelism and is generally
  // incorrect when used on distributed loops with memref semantics!
  // TODO: obsolete and should be retired.
  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
  results.push_back(target);
  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
4209
4210//===----------------------------------------------------------------------===//
4211// HoistRedundantVectorBroadcastsOp
4212//===----------------------------------------------------------------------===//
4213
/// Hoists redundant vector broadcasts out of loops under the target op.
DiagnosedSilenceableFailure
transform::HoistRedundantVectorBroadcastsOp::applyToOne(
    transform::TransformRewriter &rewriter, mlir::Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  // NOTE(review): the statement performing the actual hoisting (between
  // setting the insertion point and recording the result) is missing from
  // this extract.
  results.push_back(target);
  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
4224
4225//===----------------------------------------------------------------------===//
4226// ConvertConv2DToImg2ColOp.
4227//===----------------------------------------------------------------------===//
4228
/// Rewrites a supported 2-D convolution into an im2col + matmul form and
/// returns handles to the img2col-tensor producer and the replacement op.
DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  auto maybeTransformed =
      // NOTE(review): the line opening the TypeSwitch over `target` is
      // missing from this extract.
          target)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return rewriteInIm2Col(rewriter, op);
          })
          .Default([&](Operation *op) {
            return rewriter.notifyMatchFailure(op, "not supported");
          });
  if (failed(maybeTransformed))
    return emitDefaultSilenceableFailure(target);
  // Handle to the operation producing the img2col tensor.
  results.push_back(maybeTransformed->first);
  // Handle to the operation that replaces the original convolution.
  results.push_back(maybeTransformed->second);
  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
4260
4261//===----------------------------------------------------------------------===//
4262// FlattenElementwiseLinalgOp.
4263//===----------------------------------------------------------------------===//
4264
/// Collapses all iteration dimensions of an elementwise linalg op into one.
DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  if (!isElementwise(target))
    return mlir::emitSilenceableFailure(target->getLoc())
           << "only elementwise flattening is supported";

  // If rank <= 1, do nothing
  if (target.getNumLoops() <= 1) {
    results.push_back(target);
    // NOTE(review): the early success return inside this block is missing
    // from this extract.
  }

  // Attempt to flatten all dims to one.
  ReassociationIndices reassociation(target.getNumLoops());
  std::iota(reassociation.begin(), reassociation.end(), 0);
  auto maybeFlattened =
      collapseOpIterationDims(target, reassociation, rewriter);
  if (failed(maybeFlattened))
    return mlir::emitSilenceableFailure(target->getLoc())
           << "attempted to flatten, but failed";
  results.push_back(maybeFlattened->collapsedOp);
  rewriter.replaceOp(target, maybeFlattened->results);
  // NOTE(review): the final success-returning statement is missing from this
  // extract.
}
4292
4293//===----------------------------------------------------------------------===//
4294// TransposeConv2DOp
4295//===----------------------------------------------------------------------===//
4296
4297DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4298 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4299 transform::ApplyToEachResultList &results,
4300 transform::TransformState &state) {
4301 rewriter.setInsertionPoint(target);
4302 auto maybeTransformed =
4304 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4305 return transposeConv2D(rewriter, op);
4306 })
4307 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4308 return transposeConv2D(rewriter, op);
4309 })
4310 .Default([&](Operation *op) {
4311 return rewriter.notifyMatchFailure(op, "not supported");
4312 });
4313 if (failed(maybeTransformed))
4314 return emitDefaultSilenceableFailure(target);
4315 // Handle to the new Conv2D operation with transposed filters
4316 results.push_back(*maybeTransformed);
4318}
4319
4320//===----------------------------------------------------------------------===//
4321// TransposeMatmulOp
4322//===----------------------------------------------------------------------===//
4323
4324DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4325 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4326 transform::ApplyToEachResultList &results,
4327 transform::TransformState &state) {
4328 rewriter.setInsertionPoint(target);
4329 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4330 auto maybeTransformed =
4332 .Case([&](linalg::MatmulOp op) {
4333 return transposeMatmul(rewriter, op, transposeLHS);
4334 })
4335 .Case([&](linalg::BatchMatmulOp op) {
4336 return transposeBatchMatmul(rewriter, op, transposeLHS);
4337 })
4338 .Default(failure());
4339 if (failed(maybeTransformed))
4340 return emitSilenceableFailure(target->getLoc()) << "not supported";
4341 // Handle to the new Matmul operation with transposed filters
4342 results.push_back(*maybeTransformed);
4344}
4345
4346//===----------------------------------------------------------------------===//
4347// InsertSliceToCopyOp
4348//===----------------------------------------------------------------------===//
4349template <typename OpTy>
4350static DiagnosedSilenceableFailure
4351doit(RewriterBase &rewriter, OpTy target,
4354 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4355 tensor::ParallelInsertSliceOp>() &&
4356 "wrong op type");
4357
4358 if (auto copySource =
4359 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4360 results.push_back(copySource);
4362 }
4363
4364 // If we are inside a `ParallelCombiningOp` region, temporarily set the
4365 // insertion point outside: only ops implementing ParallelCombiningOpInterface
4366 // are allowed in there.
4367 if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4368 rewriter.setInsertionPoint(target->getParentOp());
4369
4370 Value extracted = tensor::ExtractSliceOp::create(
4371 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4372 target.getMixedSizes(), target.getMixedStrides());
4373 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4374 target.getSource(), extracted)
4375 .getResult(0);
4376 // Reset the insertion point.
4377 rewriter.setInsertionPoint(target);
4378 rewriter.replaceOpWithNewOp<OpTy>(
4379 target, copied, target.getDest(), target.getMixedOffsets(),
4380 target.getMixedSizes(), target.getMixedStrides());
4381
4382 results.push_back(copied.getDefiningOp());
4384}
4385
4386DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4387 transform::TransformRewriter &rewriter, Operation *targetOp,
4388 transform::ApplyToEachResultList &results,
4389 transform::TransformState &state) {
4390
4391 rewriter.setInsertionPoint(targetOp);
4392 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4393 return doit(rewriter, target, results, state);
4394 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4395 return doit(rewriter, target, results, state);
4396
4397 DiagnosedSilenceableFailure diag =
4398 emitSilenceableError()
4399 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4400 diag.attachNote(targetOp->getLoc()) << "target op";
4401 return diag;
4402}
4403
4404//===----------------------------------------------------------------------===//
4405// MapCopyToThreadsOp
4406//===----------------------------------------------------------------------===//
4407
4408DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4409 transform::TransformRewriter &rewriter, Operation *target,
4410 transform::ApplyToEachResultList &results,
4411 transform::TransformState &state) {
4412 // Check if the op is supported.
4413 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4414 DiagnosedSilenceableFailure diag =
4415 emitSilenceableError()
4416 << "only linalg.copy and tensor.pad target ops are supported";
4417 diag.attachNote(target->getLoc()) << "target op";
4418 return diag;
4419 }
4420 assert(target->getNumResults() == 1 && "expected single result");
4421 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4422 if (!resultShapedType.hasStaticShape()) {
4423 DiagnosedSilenceableFailure diag =
4424 emitSilenceableError()
4425 << "only statically sized ops of rank <= 3 are supported";
4426 diag.attachNote(target->getLoc()) << "target op";
4427 return diag;
4428 }
4429
4430 // Conservatively set the minimum viable desired bitwidth alignment.
4431 int64_t desiredBitAlignment = getDesiredBitAlignment();
4432 int64_t eltBitwidth =
4433 resultShapedType.getElementType().getIntOrFloatBitWidth();
4434 if (desiredBitAlignment % eltBitwidth != 0) {
4435 desiredBitAlignment = eltBitwidth;
4436 }
4437
4438 gpu::CopyMappingInfo mapping(
4439 /*ctx=*/getContext(),
4440 /*totalNumThreads=*/getTotalNumThreads(),
4441 /*alignment=*/desiredBitAlignment,
4442 /*sizes=*/resultShapedType.getShape(),
4443 /*favorPredication=*/false,
4444 /*elementalBitwidth=*/
4445 resultShapedType.getElementType().getIntOrFloatBitWidth());
4446 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4447 DiagnosedSilenceableFailure diag =
4448 emitSilenceableError()
4449 << "too few threads to map copy op to threads on the most minor "
4450 "dimension, given alignment and vector size constraints, try "
4451 "smaller tile size of mapping to more threads";
4452 diag.attachNote(target->getLoc()) << "target op";
4453 return diag;
4454 }
4455
4456 // OpBuilder only used to compute attributes.
4457 OpBuilder b(getContext());
4458 scf::SCFTilingResult tilingResult;
4459 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4460 /*rewriter=*/rewriter,
4461 /*state=*/state,
4462 /*transformOp=*/*this,
4463 /*target=*/target,
4464 /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4465 /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4466 /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4467 /*tilingResult=*/tilingResult);
4468 if (!diag.succeeded())
4469 return diag;
4470
4471 results.push_back(tilingResult.loops.front());
4472 for (auto *op : tilingResult.tiledOps)
4473 results.push_back(op);
4475}
4476
4477//===----------------------------------------------------------------------===//
4478// WinogradConv2DOp
4479//===----------------------------------------------------------------------===//
4480
4481DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4482 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4483 transform::ApplyToEachResultList &results,
4484 transform::TransformState &state) {
4485 rewriter.setInsertionPoint(target);
4486 FailureOr<Operation *> maybeTransformed = failure();
4487 bool supported = TypeSwitch<Operation *, bool>(target)
4488 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4489 maybeTransformed =
4490 winogradConv2D(rewriter, op, getFmr());
4491 return true;
4492 })
4493 .Default([&](Operation *op) { return false; });
4494
4495 if (!supported) {
4496 return emitSilenceableError()
4497 << "this operation is not supported to convert to Winograd Conv2D";
4498 }
4499
4500 if (failed(maybeTransformed)) {
4501 return emitSilenceableError() << "apply Winograd Conv2D failed";
4502 }
4503
4504 results.push_back(*maybeTransformed);
4506}
4507
4508DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4509 transform::TransformRewriter &rewriter, Operation *target,
4510 transform::ApplyToEachResultList &results,
4511 transform::TransformState &state) {
4512 rewriter.setInsertionPoint(target);
4513 FailureOr<Operation *> maybeTransformed = failure();
4514 bool supported =
4516 .Case([&](linalg::WinogradFilterTransformOp op) {
4517 maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4518 return true;
4519 })
4520 .Case([&](linalg::WinogradInputTransformOp op) {
4521 maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4522 return true;
4523 })
4524 .Case([&](linalg::WinogradOutputTransformOp op) {
4525 maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4526 return true;
4527 })
4528 .Default(false);
4529
4530 if (!supported) {
4531 DiagnosedSilenceableFailure diag =
4532 emitSilenceableError()
4533 << "this operation is not supported to decompose into other operations";
4534 diag.attachNote(target->getLoc()) << "target op";
4535 return diag;
4536 }
4537
4538 if (failed(maybeTransformed)) {
4539 DiagnosedSilenceableFailure diag =
4540 emitSilenceableError() << "decompose Winograd operations failed";
4541 diag.attachNote(target->getLoc()) << "target op";
4542 return diag;
4543 }
4544
4545 results.push_back(*maybeTransformed);
4547}
4548
4549#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4550
4551#define GET_OP_CLASSES
4552#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
for(Operation *op :ops)
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized. TODO: Replace this with a g...
#define DOWNSCALE_NORMAL(a, b)
ArrayAttr()
static bool mayBeRead(OpOperand &operand)
Return true if the operand may be read from by its owner.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static bool sameOrEquivalentIterArg(Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tileSizes, Type)
static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Base type for affine expression.
Definition AffineExpr.h:68
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition Builders.cpp:310
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:318
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getOpResult(unsigned idx)
Definition Operation.h:450
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:563
void setOperand(unsigned idx, Value value)
Definition Operation.h:380
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:589
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:426
result_type_range getResultTypes()
Definition Operation.h:457
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:292
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:826
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
result_range getOpResults()
Definition Operation.h:449
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation, bool PoisonPaddingOk=false)
Patterns to bubble up or down data layout ops across other operations.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
void populateExtractSliceSinkingPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to sink extract slice across other operations.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice + copy.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Canonicalization patterns relevant to apply after tiling patterns.
Definition Tiling.cpp:866
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition Transforms.h:489
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors and memref.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:217
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition Tiling.cpp:236
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition Tiling.cpp:106
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition Hoisting.cpp:89
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of linalg.pack/unpack ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition Hoisting.cpp:198
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition Split.cpp:68
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition Tiling.cpp:262
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:675
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:59
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:287
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition Transforms.h:521
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition Transforms.h:409
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition Transforms.h:456
LinalgPromotionOptions & setAlignment(unsigned align)
Definition Transforms.h:433
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition Transforms.h:440
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition Transforms.h:398
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition Transforms.h:420
LinalgPromotionOptions & setUseOriginalSubviewSize(bool originalSize)
Definition Transforms.h:427
LinalgPromotionOptions & setUseAlloca(bool use)
Definition Transforms.h:446
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition Transforms.h:466