LinalgTransformOps.cpp
1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
12
41#include "mlir/Support/LLVM.h"
43#include "llvm/ADT/STLExtras.h"
44#include "llvm/ADT/ScopeExit.h"
45#include "llvm/ADT/SmallPtrSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
49#include <type_traits>
50
51using namespace mlir;
52using namespace mlir::linalg;
53using namespace mlir::transform;
54
55#define DEBUG_TYPE "linalg-transforms"
56
57/// Attempts to apply the pattern specified as template argument to the given
58/// operation. The pattern is expected to have a `returningMatchAndRewrite`
59/// function that returns the "main" result or failure. Returns failure if the
60/// pattern failed to apply. Extra arguments are forwarded to the pattern
61/// constructor.
62template <typename PatternTy, typename... Args>
63static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
64 // Check if the given operation has the type expected by the pattern.
65 using OpTy = typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
68 if (!op)
69 return failure();
70
71 // Apply the pattern directly to the op.
72 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
73 // We want to discourage direct use of PatternRewriter in APIs, but in this
74 // very specific case, an IRRewriter is not enough.
75 PatternRewriter rewriter(operation->getContext());
76 rewriter.setInsertionPoint(operation);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
78 if (failed(result))
79 return failure();
80 return cast<LinalgOp>(result->getOperation());
81}
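// Illustrative use of `tryApply` (this is how the DOWNSCALE macro in
// DecomposeOp::applyToOne below invokes it; shown here only as an example):
//   FailureOr<LinalgOp> res =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(target);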
82
83/// Assuming that `ofr` is an index attr or a param of index type
84/// or a transform dialect handle mapped to exactly one op
85/// with one index result, return that value.
86static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
87 transform::TransformState &state, TransformOpInterface transformOp,
88 SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
89 for (OpFoldResult ofr : ofrs) {
90 if (auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
93 result.push_back(ofr);
94 continue;
95 }
96
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
99 ArrayRef<Attribute> params = state.getParams(transformValue);
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 << "requires exactly one parameter associated";
103 result.push_back(params[0]);
104 continue;
105 }
106
107 auto payloadOps = state.getPayloadOps(transformValue);
108 if (!llvm::hasSingleElement(payloadOps)) {
109 DiagnosedSilenceableFailure diag =
110 transformOp.emitSilenceableError()
111 << "handle must be mapped to exactly one payload op";
112 diag.attachNote(transformValue.getLoc())
113 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
114 return diag;
115 }
116
117 Operation *op = *payloadOps.begin();
118 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
119 DiagnosedSilenceableFailure diag =
120 transformOp.emitSilenceableError()
121 << "payload op must have exactly 1 index result";
122 diag.attachNote(op->getLoc())
123 << "has " << op->getNumResults() << " results";
124 return diag;
125 }
126 result.push_back(op->getResult(0));
127 }
128
129 return DiagnosedSilenceableFailure::success();
130}
131
132// Given a list of params that are index attrs or a list of OpFoldResults
133// that are either index attrs or op handles, return a list of OpFoldResults
134// of index attrs or a list of OpFoldResults where all op handles are
135// replaced with the first (and only) OpResult of that payload op.
136// (There must be exactly one parameter associated with the AnyParamType or
137// one mapped payload op which must have exactly one index result.)
138static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
139 transform::TransformState &state, TransformOpInterface transformOp,
140 SmallVector<OpFoldResult> &result, Value packedHandle) {
141 if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
142 ArrayRef<Attribute> params = state.getParams(packedHandle);
143 for (auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 << "expected the parameter to be associated with an integer "
147 "attribute";
148 result.push_back(param);
149 }
150 return DiagnosedSilenceableFailure::success();
151 }
152
153 for (Operation *op : state.getPayloadOps(packedHandle)) {
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
155 DiagnosedSilenceableFailure diag =
156 transformOp.emitSilenceableError()
157 << "payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 << "has " << op->getNumResults() << " results";
160 return diag;
161 }
162 result.push_back(op->getResult(0));
163 }
164
165 return DiagnosedSilenceableFailure::success();
166}
167
168/// When possible, converts each `OpFoldResult` in `mixedResults` to
169/// an integer if the value can be statically inferred. If a result
170/// is a `Value` then it must be either a `ParamType` or a handle
171/// to a constant-like op.
172static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
173 TransformState &state, TransformOpInterface &transformOp,
174 ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
175 for (OpFoldResult paramOrHandle : mixedResults) {
176 if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
178 continue;
179 }
180 if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
181 ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
182 if (params.size() != 1)
183 return transformOp.emitSilenceableError() << "expected a single param";
184 reified.push_back(
185 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
186 continue;
187 }
188
189 Value handle = cast<Value>(paramOrHandle);
190 if (!isa<TransformHandleTypeInterface>(handle.getType()))
191 return transformOp.emitSilenceableError() << "unexpected value handle";
192 auto payload = state.getPayloadOps(handle);
193 if (!llvm::hasSingleElement(payload))
194 return transformOp.emitSilenceableError()
195 << "requires param or handle that is mapped to 1 payload op";
196
197 Operation *paramOrHandlePayloadOp = *payload.begin();
198 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
199 !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
200 return transformOp.emitSilenceableError()
201 << "requires param or handle to be result of op with 1 index "
202 "result";
203 }
204
205 IntegerAttr attr;
206 if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
207 return transformOp.emitSilenceableError()
208 << "requires param or handle to be the result of a constant like "
209 "op";
210
211 reified.push_back(attr.getInt());
212 }
213 return DiagnosedSilenceableFailure::success();
214}
215
216//===----------------------------------------------------------------------===//
217// Apply...PatternsOp
218//===----------------------------------------------------------------------===//
219
220void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
221 RewritePatternSet &patterns) {
222 linalg::populateEraseUnnecessaryInputsPatterns(patterns);
223}
224
225void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
226 RewritePatternSet &patterns) {
227 linalg::populateDecomposePackUnpackPatterns(patterns);
228}
229
230void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
231 RewritePatternSet &patterns) {
232 linalg::populateDecomposePadPatterns(patterns);
233}
234
235void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
236 RewritePatternSet &patterns) {
237 linalg::ControlDropUnitDims options;
238 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
239}
240
241void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
242 RewritePatternSet &patterns) {
243 linalg::ControlDropUnitDims options;
244 options.rankReductionStrategy =
245 linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
246 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
247}
248
249void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
250 RewritePatternSet &patterns) {
251 linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
252}
253
254void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
255 RewritePatternSet &patterns) {
256 linalg::populateFoldAddIntoDestPatterns(patterns);
257}
258
259void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
260 RewritePatternSet &patterns) {
261 linalg::populatePadOpVectorizationPatterns(patterns);
262}
263
264void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
265 RewritePatternSet &patterns) {
266 linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
267}
268
269void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
270 RewritePatternSet &patterns) {
271 linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
272}
273
274//===----------------------------------------------------------------------===//
275// BufferizeToAllocationOp
276//===----------------------------------------------------------------------===//
277
278namespace {
279class NewOpsListener : public RewriterBase::ForwardingListener {
280public:
281 using RewriterBase::ForwardingListener::ForwardingListener;
282
283 SmallVector<Operation *> getNewOps() const {
284 return SmallVector<Operation *>(newOps.begin(), newOps.end());
285 }
286
287private:
288 void notifyOperationInserted(Operation *op,
289 OpBuilder::InsertPoint previous) override {
290 ForwardingListener::notifyOperationInserted(op, previous);
291 // We only care about newly created ops.
292 if (previous.isSet())
293 return;
294 auto inserted = newOps.insert(op);
295 (void)inserted;
296 assert(inserted.second && "expected newly created op");
297 }
298
299 void notifyOperationErased(Operation *op) override {
300 ForwardingListener::notifyOperationErased(op);
301 op->walk([&](Operation *op) { newOps.erase(op); });
302 }
303
304 SmallPtrSet<Operation *, 4> newOps;
305};
306} // namespace
307
308DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
309 transform::TransformRewriter &rewriter,
310 transform::TransformResults &results, transform::TransformState &state) {
311 // Attach listener to keep track of newly created ops.
312 OpBuilder::Listener *previousListener = rewriter.getListener();
313 auto resetListener =
314 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
315 NewOpsListener newOpsListener(previousListener);
316 rewriter.setListener(&newOpsListener);
317
318 linalg::BufferizeToAllocationOptions options;
319 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
320 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
321 MaterializeInDestination;
322 } else if (getMemcpyOp() == "memref.copy") {
323 options.memcpyOp =
324 linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
325 } else if (getMemcpyOp() == "linalg.copy") {
326 options.memcpyOp =
327 linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
328 } else {
329 llvm_unreachable("invalid memcpy op");
330 }
331 if (getAllocOp() == "memref.alloc") {
332 options.allocOp =
333 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
334 } else if (getAllocOp() == "memref.alloca") {
335 options.allocOp =
336 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
337 } else {
338 llvm_unreachable("invalid alloc op");
339 }
340 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
341 options.emitDealloc = getEmitDealloc();
342
343 // Bufferize ops.
344 Attribute memorySpace =
345 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
346 SmallVector<Value> allocatedBuffers;
347 for (Operation *op : state.getPayloadOps(getTarget())) {
348 Value buffer =
349 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
350 if (!buffer) {
351 DiagnosedSilenceableFailure diag = emitSilenceableError()
352 << "failed to bufferize operation";
353 diag.attachNote(op->getLoc()) << "target payload op";
354 return diag;
355 }
356 allocatedBuffers.push_back(buffer);
357 }
358
359 // Set results.
360 results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
361 results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
362 return DiagnosedSilenceableFailure::success();
363}
364
365void transform::BufferizeToAllocationOp::getEffects(
366 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
367 if (getBufferizeDestinationOnly()) {
368 // The destination is replaced with a newly allocated buffer, but the op
369 // itself remains in place.
370 onlyReadsHandle(getTargetMutable(), effects);
371 } else {
372 consumesHandle(getTargetMutable(), effects);
373 }
374 producesHandle(getOperation()->getOpResults(), effects);
375 modifiesPayload(effects);
376}
377
378LogicalResult transform::BufferizeToAllocationOp::verify() {
379 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
380 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
381 return emitOpError() << "unsupported memcpy op";
382 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
383 return emitOpError() << "unsupported alloc op";
384 return success();
385}
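// Illustrative transform IR driving BufferizeToAllocationOp (assumed syntax;
// the exact assembly format and attribute spelling may differ between MLIR
// versions):
//   %buffer, %new_ops = transform.structured.bufferize_to_allocation %target
//       {memory_space = 4, memcpy_op = "linalg.copy", alloc_op = "memref.alloca",
//        emit_dealloc} : !transform.any_op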
386
387//===----------------------------------------------------------------------===//
388// PromoteTensorOp
389//===----------------------------------------------------------------------===//
390
391/// Return true if the operand may be read from by its owner. This is currently
392/// very conservative and only looks inside linalg operations to prevent
393/// unintentional data loss.
394static bool mayBeRead(OpOperand &operand) {
395 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
396
397 // Be conservative about ops we cannot analyze deeper.
398 if (!linalgOp)
399 return true;
400
401 // Look inside linalg ops.
402 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
403 return !blockArgument.use_empty();
404}
405
406/// Return true if the value may be read through any of its uses.
407static bool mayBeRead(Value value) {
408 // If the value has a reference semantics, it
409 // may be read through any alias...
410 if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
411 return true;
412 return llvm::any_of(value.getUses(),
413 static_cast<bool (&)(OpOperand &)>(mayBeRead));
414}
415
416DiagnosedSilenceableFailure
417transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
418 transform::TransformResults &results,
419 transform::TransformState &state) {
420 SmallVector<Value> promoted;
421 for (Value tensor : state.getPayloadValues(getTensor())) {
422 auto type = dyn_cast<RankedTensorType>(tensor.getType());
423 if (!type) {
424 return emitSilenceableError() << "non-tensor type: " << tensor;
425 }
426
427 Operation *definingOp = tensor.getDefiningOp();
428 if (definingOp)
429 rewriter.setInsertionPointAfter(definingOp);
430 else
431 rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
432
433 // Check this before we emit operations using this value.
434 bool needsMaterialization = mayBeRead(tensor);
435
436 SmallVector<Value> dynamicDims;
437 llvm::SmallPtrSet<Operation *, 4> preservedOps;
438 for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
439 if (!ShapedType::isDynamic(dim))
440 continue;
441 Value cst =
442 arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
443 auto dimOp =
444 tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
445 preservedOps.insert(dimOp);
446 dynamicDims.push_back(dimOp);
447 }
448 auto allocation = bufferization::AllocTensorOp::create(
449 rewriter, tensor.getLoc(), type, dynamicDims);
450 // Set memory space if provided.
451 if (getMemorySpaceAttr())
452 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
453 Value allocated = allocation;
454
455 // Only insert a materialization (typically bufferizes to a copy) when the
456 // value may be read from.
457 if (needsMaterialization) {
458 auto copy = bufferization::MaterializeInDestinationOp::create(
459 rewriter, tensor.getLoc(), tensor, allocated);
460 preservedOps.insert(copy);
461 promoted.push_back(copy.getResult());
462 } else {
463 promoted.push_back(allocated);
464 }
465 rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
466 }
467 results.setValues(cast<OpResult>(getPromoted()), promoted);
468 return DiagnosedSilenceableFailure::success();
469}
470
471void transform::PromoteTensorOp::getEffects(
472 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
473 transform::onlyReadsHandle(getTensorMutable(), effects);
474 transform::producesHandle(getOperation()->getOpResults(), effects);
475 transform::modifiesPayload(effects);
476}
477
478//===----------------------------------------------------------------------===//
479// DecomposeOp
480//===----------------------------------------------------------------------===//
481
482DiagnosedSilenceableFailure
483transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
484 LinalgOp target,
485 transform::ApplyToEachResultList &results,
486 transform::TransformState &state) {
487#define DOWNSCALE(trans) \
488 { \
489 FailureOr<LinalgOp> res = tryApply<trans>(target); \
490 if (succeeded(res)) { \
491 results.push_back(*res); \
492 return DiagnosedSilenceableFailure::success(); \
493 } \
494 }
495
496#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
497#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
498
499 DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
500 DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
501 DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
502 DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
503 DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
504 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
505 DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
506 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
507 DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
508 DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
509 DOWNSCALE(DownscaleConv2DOp)
510#undef DOWNSCALE_NORMAL
511#undef DOWNSCALE_CALL
512#undef DOWNSCALE
513 return emitDefaultSilenceableFailure(target);
514}
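// For example, the first DOWNSCALE_NORMAL entry rewrites a
// linalg.conv_2d_nhwc_hwcf whose unit-size window dimension can be folded away
// into a linalg.conv_1d_nwc_wcf on rank-reduced operands; the pooling entries
// behave analogously. If no pattern applies, a default silenceable failure is
// reported.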
515
516//===----------------------------------------------------------------------===//
517// DecomposeInterfaceOp
518//===----------------------------------------------------------------------===//
519
520// Decompose the target operation if it implements the AggregatedOpInterface.
521// Push the decomposed operations (the ones that replace the values produced
522// by \p target) into `results`.
523DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
524 transform::TransformRewriter &rewriter, Operation *target,
525 transform::ApplyToEachResultList &results,
526 transform::TransformState &state) {
527 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
528 if (!decomposableOp) {
530 "payload is not a decomposable op"));
531 return emitDefaultSilenceableFailure(target);
532 }
533
534 FailureOr<SmallVector<Value>> maybeNewResults =
535 decomposableOp.decomposeOperation(rewriter);
536 if (failed(maybeNewResults))
537 return emitDefaultSilenceableFailure(target);
538
539 rewriter.replaceOp(decomposableOp, *maybeNewResults);
540 for (Value val : *maybeNewResults) {
541 Operation *definition = val.getDefiningOp();
542 if (definition)
543 results.push_back(definition);
544 }
545 return DiagnosedSilenceableFailure::success();
546}
547
548//===----------------------------------------------------------------------===//
549// EliminateLinalgOpAnchoredEmptyTensorsOp
550//===----------------------------------------------------------------------===//
551
552void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
553 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
554 onlyReadsHandle(getTargetMutable(), effects);
555 modifiesPayload(effects);
556}
557
558DiagnosedSilenceableFailure
559transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
560 transform::TransformRewriter &rewriter, TransformResults &transformResults,
561 TransformState &state) {
562 bufferization::OneShotBufferizationOptions options;
563 options.allowReturnAllocsFromLoops = true;
564
565 for (Operation *target : state.getPayloadOps(getTarget())) {
566 bufferization::OneShotAnalysisState state(target, options);
567 if (failed(analyzeOp(target, state)))
568 return mlir::emitSilenceableFailure(target->getLoc())
569 << "failed to analyze op";
570 if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
571 rewriter, target, state)))
572 return mlir::emitSilenceableFailure(target->getLoc())
573 << "failed to eliminate LinalgOp anchored tensor.empty ops";
574 }
575 return DiagnosedSilenceableFailure::success();
576}
577
578//===----------------------------------------------------------------------===//
579// FuseOp
580//===----------------------------------------------------------------------===//
581
582void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
583 TypeRange loopTypes, Value target,
584 ArrayRef<int64_t> staticTileSizes,
585 ArrayRef<int64_t> staticTileInterchange,
586 bool applyCleanup, bool useForall) {
587 return build(
588 builder, result, loopTypes,
589 /*target=*/target,
590 /*mixedTileSizes=*/
591 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
592 /*mixedTileInterchange=*/
593 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
594 applyCleanup, useForall);
595}
596
597void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
598 Value target, ArrayRef<int64_t> staticTileSizes,
599 ArrayRef<int64_t> staticTileInterchange,
600 bool applyCleanup, bool useForall) {
601 return build(
602 builder, result,
603 /*target=*/target,
604 /*mixedTileSizes=*/
605 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
606 /*mixedTileInterchange=*/
607 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
608 applyCleanup, useForall);
609}
610
611void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
612 Value target,
613 ArrayRef<OpFoldResult> mixedTileSizes,
614 ArrayRef<OpFoldResult> mixedTileInterchange,
615 bool applyCleanup, bool useForall) {
616 // Loop types are automatically splat by the callee; setting up one is
617 // enough.
618 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
619 build(builder, result, loopTypes, target, mixedTileSizes,
620 mixedTileInterchange, applyCleanup, useForall);
621}
622
623void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
624 TypeRange loopTypes, Value target,
625 ArrayRef<OpFoldResult> mixedTileSizes,
626 ArrayRef<OpFoldResult> mixedTileInterchange,
627 bool applyCleanup, bool useForall) {
628 SmallVector<int64_t> staticTileSizes;
629 SmallVector<Value> dynamicTileSizes;
630 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
631 SmallVector<int64_t> staticTileInterchange;
632 SmallVector<Value> dynamicTileInterchange;
633 dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
634 staticTileInterchange);
635 // Call the default builder which sets up the proper operands segment sizes
636 // attributes for multiple variadic operands. In the absence of this,
637 // horrible bugs ensue.
638 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
639 auto staticTileInterchangeAttr =
640 builder.getDenseI64ArrayAttr(staticTileInterchange);
641 unsigned numExpectedLoops =
642 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
643 SmallVector<Type> resultTypes;
644 resultTypes.reserve(numExpectedLoops);
645 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
646 "expected one loop type or as many as loops");
647 if (loopTypes.size() == 1)
648 resultTypes.append(numExpectedLoops, loopTypes[0]);
649 else
650 llvm::append_range(resultTypes, loopTypes);
651 build(builder, result, /*transformed=*/target.getType(),
652 /*loops=*/resultTypes,
653 /*target=*/target,
654 /*tile_sizes=*/dynamicTileSizes,
655 /*tile_interchange=*/dynamicTileInterchange,
656 /*static_tile_sizes=*/staticTileSizesAttr,
657 /*static_tile_interchange=*/staticTileInterchangeAttr,
658 /*apply_cleanup=*/applyCleanup,
659 /*use_forall=*/useForall);
660}
661
662/// Apply a tiling transformation to all payload ops and store both the
663/// tiled operation and the created tile loops.
664template <typename Range>
665static LogicalResult applyTilingToAll(
666 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
667 unsigned numLoops, transform::TransformResults &transformResults,
668 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
669 applyFn) {
670 SmallVector<Operation *> tiledLinalgOps;
671 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
672
673 for (Operation *target : payloadOps) {
674 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
675 if (!tilingInterfaceOp)
676 return transformOp->emitError("only TilingInterface ops are supported");
677
678 rewriter.setInsertionPoint(target);
679 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
680 applyFn(tilingInterfaceOp);
681 if (failed(tiledResults))
682 return failure();
683
684 // Perform the replacement of tiled and fused values.
685 SmallVector<Operation *> opsToReplace{target};
686 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
687 for (Operation *toReplace : opsToReplace) {
688 for (OpResult res : toReplace->getResults())
689 if (auto replacement = tiledResults->replacements.lookup(res))
690 rewriter.replaceAllUsesWith(res, replacement);
691 if (toReplace->use_empty()) {
692 rewriter.eraseOp(toReplace);
693 }
694 }
695
696 // Report back the relevant handles to the transform op.
697 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
698 assert(tiledResults->loops.size() == numLoops &&
699 "Mismatched number of loops, tile and fuse transform should have "
700 "failed");
701 for (unsigned int i = 0; i < numLoops; ++i)
702 loopOps[i].push_back(tiledResults->loops[i]);
703 }
704
705 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
706 for (unsigned int i = 0; i < numLoops; ++i)
707 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
708
709 return success();
710}
711
712DiagnosedSilenceableFailure
713transform::FuseOp::apply(transform::TransformRewriter &rewriter,
714 mlir::transform::TransformResults &transformResults,
715 mlir::transform::TransformState &state) {
716 auto transformOp = cast<TransformOpInterface>(getOperation());
717
718 SmallVector<int64_t> tileSizes;
719 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
720 state, transformOp, getMixedTileSizes(), tileSizes);
721 if (!status.succeeded())
722 return status;
723 SmallVector<int64_t> tileInterchange;
724 status = reifyMixedParamAndHandleResults(
725 state, transformOp, getMixedTileInterchange(), tileInterchange);
726 if (!status.succeeded())
727 return status;
728
729 scf::SCFTilingOptions tilingOptions;
730 tilingOptions.interchangeVector = tileInterchange;
731 bool useForall = getUseForall();
732 tilingOptions.setLoopType(useForall
733 ? scf::SCFTilingOptions::LoopType::ForallOp
734 : scf::SCFTilingOptions::LoopType::ForOp);
735 SmallVector<OpFoldResult> tileSizesOfr =
736 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
737 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
738 scf::SCFTileAndFuseOptions tileAndFuseOptions;
739 tileAndFuseOptions.tilingOptions = tilingOptions;
740
741 if (getApplyCleanup()) {
742 MLIRContext *context = rewriter.getContext();
743 RewritePatternSet patterns(context);
744 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
747 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
748 }
749
750 size_t numLoops =
751 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
752 LogicalResult result = applyTilingToAll(
753 rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
754 transformResults,
755 [&](TilingInterface tilingInterfaceOp)
756 -> FailureOr<scf::SCFTileAndFuseResult> {
757 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
758 tileAndFuseOptions);
759 });
760 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
761 : DiagnosedSilenceableFailure::success();
762}
763
764LogicalResult transform::FuseOp::verify() {
765 auto iterspace_rank = getStaticTileSizes().size();
766 ArrayRef<int64_t> permutation = getStaticTileInterchange();
767 if (permutation.size() > iterspace_rank)
768 return emitOpError()
769 << "interchange length exceeds iteration space dimensions ("
770 << iterspace_rank << "), found " << getTileInterchange();
771 SmallVector<bool> seen(iterspace_rank, false);
772 for (int64_t v : permutation) {
773 if (!ShapedType::isDynamic(v)) {
774 if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))
775 return emitOpError() << "expects interchange values to be in range [0, "
776 << iterspace_rank << "), found: " << v;
777 if (seen[v])
778 return emitOpError() << "found duplicate interchange value: " << v;
779 seen[v] = true;
780 }
781 }
782
783 ArrayRef<int64_t> sizes = getStaticTileSizes();
784 size_t numExpectedLoops =
785 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
786 if (numExpectedLoops != getNumResults() - 1)
787 return emitOpError() << "expects " << numExpectedLoops << " loop results";
788
789 return success();
790}
791
792SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
793 return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
794}
795
796SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
797 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
798 getContext());
799}
800
801void transform::FuseOp::getEffects(
802 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
803 consumesHandle(getTargetMutable(), effects);
804 onlyReadsHandle(getTileSizesMutable(), effects);
805 onlyReadsHandle(getTileInterchangeMutable(), effects);
806 producesHandle(getOperation()->getOpResults(), effects);
807 modifiesPayload(effects);
808}
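// Illustrative transform IR exercising FuseOp (assumed syntax):
//   %tiled, %loop0, %loop1 = transform.structured.fuse %matmul
//       {tile_sizes = [32, 32, 0], tile_interchange = [1, 0, 2]}
//       : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)
// With two non-zero tile sizes and use_forall unset, two loop handles are
// produced in addition to the tiled-op handle, matching verify() above.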
809
810//===----------------------------------------------------------------------===//
811// FuseIntoContainingOp
812//===----------------------------------------------------------------------===//
813
814void transform::FuseIntoContainingOp::build(OpBuilder &builder,
815 OperationState &result,
816 Value producerOp,
817 Value containingOp) {
818 result.addOperands({producerOp, containingOp});
819 auto resultType = transform::AnyOpType::get(builder.getContext());
820 result.addTypes({resultType, resultType});
821}
822
823/// Add new operands to the forall op for users of the producerOp
824/// that are dominated by the containing scf.forall op.
825static Operation *replaceForAllWithNewSignature(
826 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
827 Operation *containingOp, TilingResult &tileAndFuseResult,
828 int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
829 SmallVector<OpFoldResult> &sizes) {
830
831 // Count number of users not including the containing op
832 SetVector<Operation *> dominatedUsers;
833 DominanceInfo domInfo(containingOp);
834 for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
835 if (!containingOp->isAncestor(user) &&
836 (domInfo.dominates(containingOp, user))) {
837 dominatedUsers.insert(user);
838 }
839 }
840 if (dominatedUsers.empty())
841 return nullptr;
842
843 // Create new scf.forall op
844 auto forallOp = cast<scf::ForallOp>(containingOp);
845 OpBuilder::InsertionGuard g(rewriter);
846 rewriter.setInsertionPoint(forallOp);
847
848 // Get new output
849 Location loc = forallOp.getLoc();
850 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
851 if (!genericOp)
852 return nullptr;
853 SmallVector<Value> outputs = genericOp.getOutputs();
854 SmallVector<Value> newOuts(forallOp.getOutputs());
855 newOuts.push_back(outputs[resultNumber]);
856
857 // Create new scf.forall op
858 auto newforallOp = scf::ForallOp::create(
859 rewriter, loc, forallOp.getMixedLowerBound(),
860 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
861 forallOp.getMapping());
862 rewriter.eraseBlock(newforallOp.getBody());
863 newforallOp.getRegion().takeBody(forallOp.getRegion());
864
865 // Add additional block argument for new value being returned
866 // and replaces all uses of the new output with corresponding bbArg
867 // inside the scf.forall to enable fusion into this new scf.forall.
868 newforallOp.getBody()->addArgument(newOuts.back().getType(),
869 newOuts.back().getLoc());
870 auto bbArgs = newforallOp.getBody()->getArguments();
871 rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
872 [&](OpOperand &use) {
873 Operation *op = use.getOwner();
874 return newforallOp->isProperAncestor(op);
875 });
876
877 // Fix terminator
878 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
879 SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
880 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
881 Operation *firstYieldOp = yieldingOps.front();
882 rewriter.setInsertionPoint(firstYieldOp);
883 Value src = tileAndFuseResult.tiledValues[0];
884 Value dst = newforallOp.getRegionIterArgs().back();
885 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
886 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
887 dst, offsets, sizes, strides);
888
889 for (auto result : llvm::enumerate(forallOp.getResults())) {
890 rewriter.replaceAllUsesWith(result.value(),
891 newforallOp->getResult(result.index()));
892 }
893 rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
894 newforallOp->getResults().back(),
895 [&](OpOperand &use) {
896 Operation *user = use.getOwner();
897 return dominatedUsers.contains(user);
898 });
899 return newforallOp;
900}
901
902/// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
903/// if the operand 'src' is equal to 'dst' or equal to an iter arg present in
904/// an outer loop. To determine the second condition, this function iterates
905/// using a worklist over the enclosing loops, trying to find 'src' in any of
906/// the parent loop's iter args.
907static bool sameOrEquivalentIterArg(Value src, Value dst) {
908 // Stack-like vector containing possible iterArg candidates. The first one
909 // is dst, and we will traverse the IR from there.
910 SmallVector<Value> destWorklist;
911 destWorklist.push_back(dst);
912
913 while (!destWorklist.empty()) {
914 Value currentDst = destWorklist.pop_back_val();
915
916 // We have found the same operand in some iter arg in the loop structure,
917 // so src and dst are equivalent.
918 if (src == currentDst)
919 return true;
920
921 // The operands are not equivalent, look for enclosing loops over
922 // currentDst.
923 auto bbArg = dyn_cast<BlockArgument>(currentDst);
924 if (!bbArg)
925 continue;
926
927 Block *parentBlock = bbArg.getOwner();
928 assert(parentBlock && "unlinked block argument");
929
930 Operation *parentOp = parentBlock->getParentOp();
931 assert(parentOp && "expected block argument with parent operation");
932
933 // Check if parent is loop-like. If it's not, do not add it to the worklist.
934 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
935 if (!parentLoop)
936 continue;
937
938 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
939 // No need to check for null as innerIterArg is tied to parentLoop.
940 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
941 Value loopBlockArgument =
942 parentLoop->getOperand(operand->getOperandNumber());
943 destWorklist.push_back(loopBlockArgument);
944 }
945 }
946
947 return false;
948}
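// Sketch of the situation this helper recognizes (pseudo-IR, illustration
// only): with
//   %r = scf.for ... iter_args(%outer = %init) {
//     %s = scf.forall ... shared_outs(%inner = %outer) { ... }
//   }
// sameOrEquivalentIterArg(%init, %inner) returns true: walking outwards from
// %inner reaches %outer and then %init, which equals 'src'.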
949
950/// Find the first "extract" user of `producerOp` and tile it right before its
951/// use. The tiled op is fused under the `containingOp`.
952/// Return this fused op on success or nullptr if anything fails.
953/// If tiled op has uses that are dominated by `containingOp`, return
954/// a new `containingOp` with results of the fused op appended to
955/// results of the `containingOp` or nullptr if there are no dominated uses.
956static std::tuple<SmallVector<Operation *>, Operation *>
957tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
958 Operation *producerOp, Operation *containingOp) {
959 LDBG() << "Try to fuse a direct extract use";
960 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
961 if (!tileableProducer) {
962 diag.attachNote(producerOp->getLoc())
963 << "producer is not a TileableInterface: " << *producerOp;
964 return {};
965 }
966
967 // Search the producer slices accessed within the containing operation.
968 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
969 // evolve into an interface.
970 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
971 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
972 return sliceOp && containingOp->isProperAncestor(sliceOp);
973 });
974
975 // Find a fusion opportunity.
976 if (it == tileableProducer->getUsers().end()) {
977 diag.attachNote(tileableProducer->getLoc())
978 << "could not find fusion opportunity for: " << *tileableProducer;
979 return {};
980 }
981 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
982
983 // Try to fuse the producer in-place.
984 OpBuilder::InsertionGuard guard(rewriter);
985 rewriter.setInsertionPoint(sliceOpToTile);
986
987 // Clone the producer inside the consumer and try to update the producer init
988 // operands using the loop bbArgs if applicable. More precisely, if the bbArg
989 // of the container loop points to a value that is used by the consumer op,
990 // then, instead of using that value in the consumer, use the value coming
991 // from the bbArg instead. This allows reusing the output tensor (instead of
992 // creating a new one) of the container when both producer and container write
993 // to the same output.
994 if (LoopLikeOpInterface containerLoop =
995 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
996 Operation *clone = rewriter.clone(*producerOp);
997 rewriter.modifyOpInPlace(clone, [&]() {
998 // Iterate over the outputs of the producer and over the loop bbArgs and
999 // check if any bbArg points to the same value as the producer output. In
1000 // such case, make the producer output point to the bbArg directly.
1001 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
1002 if (!dpsInterface)
1003 return;
1004
1005 for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1006 Value producerOperand =
1007 clone->getOperand(initOperandPtr.getOperandNumber());
1008 for (BlockArgument containerIterArg :
1009 containerLoop.getRegionIterArgs()) {
1010 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1011 Value consumerOperand =
1012 containerLoop->getOperand(bbArg->getOperandNumber());
1013 // The producer has the same init as the loop bbArg, use it.
1014 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
1015 initOperandPtr.set(containerIterArg);
1016 }
1017 }
1018 }
1019 });
1020
1021 tileableProducer = dyn_cast<TilingInterface>(clone);
1022 }
1023
1024 // Tile the producer.
1025 int64_t resultNumber =
1026 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1027 LDBG() << "resultNumber: " << resultNumber;
1028
1029 SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
1030 SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
1031
1032 FailureOr<TilingResult> tileAndFuseResult =
1033 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1034 sizes);
1035
1036 if (failed(tileAndFuseResult)) {
1037 diag.attachNote(tileableProducer->getLoc())
1038 << "failed to tile producer op: " << *tileableProducer;
1039 return {};
1040 }
1041
1042#ifndef NDEBUG
1043 for (auto *tiledOp : tileAndFuseResult->tiledOps) {
1044 LDBG() << "tiledProducer: " << *tiledOp;
1045 }
1046#endif
1047
1048 // Replace the extract op.
1049 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1050 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1051 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1052 if (failed(maybeRankReduced)) {
1053 diag.attachNote(producerOp->getLoc())
1054 << "shape types don't match (missing canonicalization?):\nTiledOp: "
1055 << tileAndFuseResult->tiledValues[0]
1056 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
1057 return {};
1058 }
1059 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1060
1061 // Add new outputs to containing op, if required
1062 Operation *newContainingOp = replaceForAllWithNewSignature(
1063 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
1064 resultNumber, offsets, sizes);
1065
1066 // Cleanup clone.
1067 if (isa<LoopLikeOpInterface>(containingOp))
1068 rewriter.eraseOp(tileableProducer);
1069
1070 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1071}
1072
1073/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
1074/// it is exactly the `containingOp`, otherwise bail.
1075/// Then, find the first "extract" user of the tied block argument and tile it
1076/// right before its "extract" use. The tiled op is fused under the
1077/// `containingOp`.
1078/// Return this fused op on success or nullptr if anything fails.
1079static SmallVector<Operation *>
1080tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1081 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
1082 Operation *containingOp) {
1083 LDBG() << "Try to fuse an extract use through block argument";
1084
1085 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1086 if (!tileableProducer) {
1087 diag.attachNote(producerOp->getLoc())
1088 << "producer is not a TileableInterface: " << *producerOp;
1089 return {};
1090 }
1091
1092 // Search the first use by a "scf::ForallOp" user.
1093 scf::ForallOp forallOp;
1094 auto itProducerUses =
1095 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
1096 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1097 return forallOp;
1098 });
1099 // If it's not from the containing op, return.
1100 if (!forallOp || forallOp != containingOp) {
1101 diag.attachNote(tileableProducer->getLoc())
1102 << "could not find a use by the containing op: " << *tileableProducer;
1103 return {};
1104 }
1105
1106 // Search the producer slices accessed within the containing
1107 // operation.
1108 // TODO: Generalize to more extract/insert/parallel_insert triples.
1109 // Maybe evolve into an interface.
1110 OpOperand *pUse = &(*itProducerUses);
1111 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
1112
1113 // Search the producer slices accessed within the containing operation.
1114 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
1115 // evolve into an interface.
1116 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
1117 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1118 return sliceOp && containingOp->isProperAncestor(sliceOp);
1119 });
1120
1121 // Find a fusion opportunity.
1122 if (itBBArgUsers == bbArg.getUsers().end()) {
1123 diag.attachNote(containingOp->getLoc())
1124 << "could not find fusion opportunity for bbArg: " << bbArg;
1125 return {};
1126 }
1127 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1128
1129 // Try to fuse the producer in-place.
1130 OpBuilder::InsertionGuard guard(rewriter);
1131 rewriter.setInsertionPoint(sliceOpToTile);
1132
1133 // Replace the use in the tileableProducer before tiling: clone, replace and
1134 // then tile.
1135 int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
1136 LDBG() << "resultNumber: " << resultNumber;
1137
1138 // Gather destination tensors.
1139 SmallVector<Value> destinationTensors;
1140 if (failed(tensor::getOrCreateDestinations(
1141 rewriter, tileableProducer->getLoc(), tileableProducer,
1142 destinationTensors))) {
1143 diag.attachNote(tileableProducer->getLoc())
1144 << "failed to get destination tensors for: " << *tileableProducer;
1145 return {};
1146 }
1147
1148 IRMapping bvm;
1149 bvm.map(destinationTensors[resultNumber], bbArg);
1150 auto tileableProducerClone =
1151 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
1152 auto scopeGuard =
1153 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
1154
1155 // Tile the producer.
1156 FailureOr<TilingResult> tileAndFuseResult =
1157 tileableProducerClone.generateResultTileValue(
1158 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1159 sliceOpToTile.getMixedSizes());
1160 if (failed(tileAndFuseResult)) {
1161 diag.attachNote(tileableProducer->getLoc())
1162 << "failed to tile producer op: " << *tileableProducer;
1163 return {};
1164 }
1165
1166 // Replace the extract op.
1167 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1168 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1169 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1170 assert(succeeded(maybeRankReduced) && "unexpected shape");
1171 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1172
1173 // Replace the use in containingOp.
1174 rewriter.modifyOpInPlace(containingOp, [&]() {
1175 containingOp->setOperand(pUse->getOperandNumber(),
1176 destinationTensors.front());
1177 });
1178
1179 return tileAndFuseResult->tiledOps;
1180}
1181
1182static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
1183 Operation *producerOp,
1184 Operation *containingOp) {
1185 LDBG() << "Try to fuse an use by cloning";
1186
1187 // Gather all uses inside the containing op.
1188 SmallVector<OpOperand *> uses;
1189 for (OpResult result : producerOp->getOpResults()) {
1190 for (OpOperand &use : result.getUses()) {
1191 if (containingOp->isProperAncestor(use.getOwner())) {
1192 uses.push_back(&use);
1193 continue;
1194 }
1195 // Cannot clone and fuse if the use is by the containing op itself: fail
1196 // immediately.
1197 if (containingOp == use.getOwner()) {
1198 diag.attachNote(producerOp->getLoc())
1199 << "producer op use by containing op cannot be fused by cloning";
1200 return nullptr;
1201 }
1202 }
1203 }
1204
1205 // Check for a non-empty list of fusion opportunities.
1206 if (uses.empty()) {
1207 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1208 return nullptr;
1209 }
1210
1211 // Clone and fuse inside the containing op.
1212 Operation *fusedOp = nullptr;
1213 OpOperand *use = uses.front();
1214 // Parallel insert slice is not a valid clone destination.
1215 // TODO: Generalize to other type of ops.
1216 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1217 "Parallel insert slice is not a valid clone destination");
1218 unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1219 LDBG() << "resultNumber: " << resultNumber;
1220
1221 OpBuilder::InsertionGuard guard(rewriter);
1222 rewriter.setInsertionPoint(use->getOwner());
1223 fusedOp = rewriter.clone(*producerOp);
1224 rewriter.modifyOpInPlace(
1225 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1226
1227 return fusedOp;
1228}
1229
1230bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1231 // Allow repeated handles since we are fusing everything anyway.
1232 return true;
1233}
1234
1235DiagnosedSilenceableFailure
1236transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1237 transform::TransformResults &results,
1238 transform::TransformState &state) {
1239 SmallVector<Operation *> fusedOps;
1240 auto producerOps = state.getPayloadOps(getProducerOp());
1241 auto containingOps = state.getPayloadOps(getContainingOp());
1242 if (!llvm::hasSingleElement(containingOps)) {
1243 return emitDefiniteFailure()
1244 << "requires exactly one containing_op handle (got "
1245 << llvm::range_size(containingOps) << ")";
1246 }
1247 Operation *containingOp = *containingOps.begin();
1248
1249 // If nothing to fuse, propagate success.
1250 if (std::empty(producerOps)) {
1251 results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1252 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1253 return DiagnosedSilenceableFailure::success();
1254 }
1255
1256 // Helper function to find the next producer that should be fused. Take any
1257 // producer that has a use inside the containing op.
1258 SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1259 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1260 for (const auto &it : enumerate(remainingProducers)) {
1261 Operation *producerOp = it.value();
1262 // The containing op may be a user of producerOp: use isAncestor.
1263 int64_t numUsesInContainingOp =
1264 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1265 return containingOp->isAncestor(op);
1266 });
1267 // TODO: When resolving the TODO below (no duplicate ops), take an op
1268 // that has no use among the remaining producers. This is a topological
1269 // sorting.
1270 if (numUsesInContainingOp > 0) {
1271 if (numUsesInContainingOp == 1)
1272 remainingProducers.erase(remainingProducers.begin() + it.index());
1273 return producerOp;
1274 }
1275 }
1276 return failure();
1277 };
1278
1279 while (!remainingProducers.empty()) {
1280 auto nextProducer = getNextProducer();
1281 if (failed(nextProducer)) {
1282 auto diag = mlir::emitSilenceableFailure(getLoc())
1283 << "could not find next producer to fuse into container";
1284 diag.attachNote(containingOp->getLoc()) << "containing op";
1285 return diag;
1286 }
1287
1288 Operation *producerOp = *nextProducer;
1289
1290 // Default diagnostic, to be complemented with more failure information.
1292 diag << "could not fuse " << *producerOp << " into " << *containingOp;
1293
1294 // TODO: If there are multiple uses of the producer in the containing op,
1295 // we currently tile/clone the op multiple times (once per use). In some
1296 // cases, we can tile/clone once and reuse the value for each use.
1297 // Furthermore, producers should then be traversed according to a
1298 // topological sorting.
1299 auto [tiledOps, newContainingOp] =
1300 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1301 if (!tiledOps.empty()) {
1302 LDBG() << "\nFused a direct extract use\n" << *containingOp;
1303 fusedOps.append(tiledOps);
1304 if (newContainingOp) {
1305 // Update handles associated with the containing op so we don't need to
1306 // invalidate them. This is a hack to support better composability
1307 // between tiling and fusion while a proper mechanism is being
1308 // investigated.
1309 //
1310 // DO NOT replicate this elsewhere unless you understand what you are
1311 // doing.
1312 LogicalResult replacementStatus =
1313 rewriter.notifyPayloadOperationReplaced(containingOp,
1314 newContainingOp);
1315 (void)replacementStatus;
1316 assert(succeeded(replacementStatus) &&
1317 "unable to update transform state mapping");
1318 rewriter.eraseOp(containingOp);
1319 containingOp = newContainingOp;
1320 }
1321 continue;
1322 }
1323
1324 SmallVector<Operation *> tiledContainingOpOperand =
1325 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1326 rewriter, diag, producerOp, containingOp);
1327 if (!tiledContainingOpOperand.empty()) {
1328 LDBG() << "\nFused an extract use through block argument\n"
1329 << *containingOp;
1330 fusedOps.append(tiledContainingOpOperand);
1331 continue;
1332 }
1333
1334 Operation *cloned =
1335 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1336 if (cloned) {
1337 LDBG() << "\nFused an use by cloning\n" << *containingOp;
1338 fusedOps.push_back(cloned);
1339 continue;
1340 }
1341 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1342 }
1343
1344 results.set(cast<OpResult>(getFusedOp()), fusedOps);
1345 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1346 return DiagnosedSilenceableFailure::success();
1347}
1348
1349void transform::FuseIntoContainingOp::getEffects(
1350 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1351 consumesHandle(getProducerOpMutable(), effects);
1352 onlyReadsHandle(getContainingOpMutable(), effects);
1353 producesHandle(getOperation()->getOpResults(), effects);
1354 modifiesPayload(effects);
1355}
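// Illustrative transform IR (assumed syntax):
//   %fused, %new_containing = transform.structured.fuse_into_containing_op
//       %producer into %forall
//       : (!transform.any_op, !transform.any_op)
//       -> (!transform.any_op, !transform.any_op)
// The producer handle is consumed; the containing-op handle is only read and
// the result is re-mapped to the (possibly rebuilt) scf.forall.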
1356
1357//===----------------------------------------------------------------------===//
1358// GeneralizeOp
1359//===----------------------------------------------------------------------===//
1360
1361DiagnosedSilenceableFailure
1362transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1363 LinalgOp target,
1364 transform::ApplyToEachResultList &results,
1365 transform::TransformState &state) {
1366 // Exit early if no transformation is needed.
1367 if (isa<GenericOp>(target)) {
1368 results.push_back(target);
1369 return DiagnosedSilenceableFailure::success();
1370 }
1371 rewriter.setInsertionPoint(target);
1372 FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1373 if (succeeded(generic)) {
1374 results.push_back(generic->getOperation());
1375 return DiagnosedSilenceableFailure::success();
1376 }
1377 return emitDefaultSilenceableFailure(target);
1378}
1379
1380//===----------------------------------------------------------------------===//
1381// SpecializeOp
1382//===----------------------------------------------------------------------===//
1383
1384DiagnosedSilenceableFailure
1385transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1386 LinalgOp target,
1387 transform::ApplyToEachResultList &results,
1388 transform::TransformState &state) {
1389 // Exit early if the operation is not a generic.
1390 if (!isa<GenericOp>(target)) {
1391 results.push_back(target);
1392 return DiagnosedSilenceableFailure::success();
1393 }
1394 rewriter.setInsertionPoint(target);
1395 FailureOr<LinalgOp> named =
1396 specializeGenericOp(rewriter, cast<GenericOp>(target));
1397 if (succeeded(named)) {
1398 results.push_back(named->getOperation());
1399 return DiagnosedSilenceableFailure::success();
1400 }
1401 return emitDefaultSilenceableFailure(target);
1402}
1403
1404//===----------------------------------------------------------------------===//
1405// InterchangeOp
1406//===----------------------------------------------------------------------===//
1407
1408DiagnosedSilenceableFailure
1409transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1410 GenericOp target,
1411 transform::ApplyToEachResultList &results,
1412 transform::TransformState &state) {
1413 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1414 // Exit early if no transformation is needed.
1415 if (interchangeVector.empty()) {
1416 results.push_back(target);
1417 return DiagnosedSilenceableFailure::success();
1418 }
1419
1420 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1421 if (interchangeVector.size() != numLoops) {
1422 return emitSilenceableError()
1423 << getIteratorInterchangeAttrName() << " has length ("
1424 << interchangeVector.size()
1425 << ") different from the number of loops in the target operation ("
1426 << numLoops << ")";
1427 }
1428 FailureOr<GenericOp> res = interchangeGenericOp(
1429 rewriter, target, SmallVector<unsigned>(interchangeVector));
1430 if (failed(res))
1431 return emitDefiniteFailure() << "failed to apply";
1432 results.push_back(res->getOperation());
1433 return DiagnosedSilenceableFailure::success();
1434}
1435
1436LogicalResult transform::InterchangeOp::verify() {
1437 ArrayRef<int64_t> permutation = getIteratorInterchange();
1438 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1439 if (!std::is_permutation(sequence.begin(), sequence.end(),
1440 permutation.begin(), permutation.end())) {
1441 return emitOpError()
1442 << "expects iterator_interchange to be a permutation, found "
1443 << getIteratorInterchange();
1444 }
1445 return success();
1446}
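// Illustrative transform IR (assumed syntax): swap the two loops of a
// linalg.generic.
//   %swapped = transform.structured.interchange %generic
//       iterator_interchange = [1, 0]
//       : (!transform.any_op) -> !transform.any_op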
1447
1448//===----------------------------------------------------------------------===//
1449// LinalgCopyToMemrefOp
1450//===----------------------------------------------------------------------===//
1451
1452DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1453 transform::TransformRewriter &rewriter, Operation *targetOp,
1454 transform::ApplyToEachResultList &results,
1455 transform::TransformState &state) {
1456
1457 // Check if the target can be converted.
1458 if (!isa<linalg::CopyOp>(targetOp)) {
1460 emitSilenceableError() << "only linalg.copy target ops are supported";
1461 diag.attachNote(targetOp->getLoc()) << "target op";
1462 return diag;
1463 }
1464
1465 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1466 if (!copyOp.hasPureBufferSemantics()) {
1467 DiagnosedSilenceableFailure diag =
1468 emitSilenceableError()
1469 << "cannot transform a linalg.copy on tensors into a memref.copy";
1470 diag.attachNote(targetOp->getLoc()) << "target op";
1471 return diag;
1472 }
1473
1474 SmallVector<Value> inputs = copyOp.getInputs();
1475 SmallVector<Value> outputs = copyOp.getOutputs();
1476 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1477 assert(outputs.size() == 1 && "expected linalg copy op with one output");
1478 Value input = inputs.front();
1479 Value output = outputs.front();
1480
1481 // linalg.copy supports different element types on source/dest whereas
1482 // memref.copy does not, so we must check that the source and dest types can
1483 // be handled by memref.copy and otherwise reject the transformation.
1484 if (!isa<ShapedType>(input.getType())) {
1485 DiagnosedSilenceableFailure diag =
1486 emitSilenceableError()
1487 << "cannot transform a linalg.copy whose input has no shape";
1488 diag.attachNote(targetOp->getLoc()) << "target op";
1489 return diag;
1490 }
1491
1492 // linalg.copy destination must be a shaped type.
1493 assert(isa<ShapedType>(output.getType()));
1494
1495 if (cast<ShapedType>(input.getType()).getElementType() !=
1496 cast<ShapedType>(output.getType()).getElementType()) {
1497 DiagnosedSilenceableFailure diag =
1498 emitSilenceableError()
1499 << "cannot transform a linalg.copy with different source and "
1500 "destination element types ";
1501 diag.attachNote(targetOp->getLoc()) << "target op";
1502 return diag;
1503 }
1504
1505 // Target can be converted, do it.
1506 auto memrefCopyOp =
1507 rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1508
1509 results.push_back(memrefCopyOp);
1510 return DiagnosedSilenceableFailure::success();
1511}
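// Net effect on payload IR, sketched (only legal for pure-memref copies with
// matching element types, as checked above):
//   linalg.copy ins(%src : memref<16xf32>) outs(%dst : memref<16xf32>)
// becomes
//   memref.copy %src, %dst : memref<16xf32> to memref<16xf32>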
1512
1513//===----------------------------------------------------------------------===//
1514// LowerPackOp
1515//===----------------------------------------------------------------------===//
1516
1517DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1518 transform::TransformRewriter &rewriter, linalg::PackOp target,
1519 transform::ApplyToEachResultList &transformResults,
1520 transform::TransformState &state) {
1521 rewriter.setInsertionPoint(target);
1522 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1523 FailureOr<LowerPackResult> res =
1524 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1525 if (failed(res)) {
1526 return mlir::emitSilenceableFailure(target->getLoc())
1527 << "cannot lower to pad + expand + transpose";
1528 }
1529 transformResults.push_back(res->padOp);
1530 transformResults.push_back(res->expandShapeOp);
1531 transformResults.push_back(res->transposeOp);
1532 return DiagnosedSilenceableFailure::success();
1533}
1534
1535//===----------------------------------------------------------------------===//
1536// LowerUnPackOp
1537//===----------------------------------------------------------------------===//
1538
1539DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1540 transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1541 transform::ApplyToEachResultList &transformResults,
1542 transform::TransformState &state) {
1543 rewriter.setInsertionPoint(target);
1544 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1545 FailureOr<LowerUnPackOpResult> res =
1546 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1547 if (failed(res)) {
1548 DiagnosedSilenceableFailure diag =
1549 emitSilenceableError()
1550 << "cannot lower to transpose + collapse + extract";
1551 diag.attachNote(target->getLoc()) << "target payload op";
1552 return diag;
1553 }
1554 transformResults.push_back(res->emptyOp);
1555 transformResults.push_back(res->transposeOp);
1556 transformResults.push_back(res->collapseShapeOp);
1557 transformResults.push_back(res->extractSliceOp);
1558 return DiagnosedSilenceableFailure::success();
1559}
1560
1561//===---------------------------------------------------------------------===//
1562// MatchOp
1563//===---------------------------------------------------------------------===//
1564
1565void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1566 Value target, ArrayRef<StringRef> opNames) {
1567 result.addOperands(target);
1568 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1569 builder.getStrArrayAttr(opNames));
1570 result.addTypes(transform::AnyOpType::get(builder.getContext()));
1571}
1572
1573void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1574 TypeRange resultTypes, Value target,
1575 ArrayRef<StringRef> opNames) {
1576 result.addOperands(target);
1577 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1578 builder.getStrArrayAttr(opNames));
1579 result.addTypes(resultTypes);
1580}
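
// Illustrative usage sketch for this op: the handle names are placeholders
// and the exact assembly syntax is defined by the op's TableGen declaration,
// so this is only indicative, e.g. matching all linalg.matmul payload ops:
//
//   %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
//     : (!transform.any_op) -> !transform.any_op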
1581
1582DiagnosedSilenceableFailure
1583transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1584                          transform::TransformResults &results,
1585                          transform::TransformState &state) {
1586 llvm::StringSet<> strs;
1587 if (getOps().has_value())
1588 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1589
1590 auto payloadOps = state.getPayloadOps(getTarget());
1591 if (!llvm::hasSingleElement(payloadOps)) {
1592 return emitDefiniteFailure("requires exactly one target handle");
1593 }
1594
1595  SmallVector<Operation *> res;
1596  bool incorrectNumOperandTypes = false;
1597 auto matchFun = [&](Operation *op) {
1598 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1599 return;
1600
1601 // Interfaces cannot be matched by name, just by ID.
1602 // So we specifically encode the interfaces we care about for this op.
1603 if (getInterface().has_value()) {
1604 auto iface = getInterface().value();
1605 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1606 !isa<LinalgOp>(op))
1607 return;
1608 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1609 !isa<TilingInterface>(op))
1610 return;
1611 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1612 !isa<LoopLikeOpInterface>(op))
1613 return;
1614 }
1615
1616 // Check if all specified attributes match.
1617 if (getOpAttrs().has_value()) {
1618 DictionaryAttr opAttrs = getOpAttrs().value();
1619 for (NamedAttribute attr : opAttrs) {
1620 if (attr.getName() == getInterfaceAttrName() ||
1621 attr.getName() == getOpsAttrName())
1622 continue;
1623 if (!op->hasAttr(attr.getName()))
1624 return;
1625 if (op->getAttr(attr.getName()) != attr.getValue())
1626 return;
1627 }
1628 }
1629
1630 if (getFilterResultType().has_value()) {
1631 Type t = getFilterResultType().value();
1632 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1633 return;
1634 }
1635
1636 if (getFilterOperandTypes().has_value()) {
1637 mlir::ArrayAttr types = getFilterOperandTypes().value();
1638 auto operandTypes = op->getOperandTypes();
1639
1640 if (types.size() == 1) {
1641        // All the operands must be equal to the specified type.
1642 auto typeattr =
1643 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1644 Type t = cast<::mlir::Type>(typeattr.getValue());
1645 if (!llvm::all_of(op->getOperandTypes(),
1646 [&](Type operandType) { return operandType == t; }))
1647 return;
1648 } else {
1649        // The operand types must match all the types in the list (in the same
1650        // order in which they are specified).
1651 if (types.size() != operandTypes.size()) {
1652 incorrectNumOperandTypes = true;
1653 return;
1654 }
1655
1656 for (auto [attr, operandType] :
1657 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1658 auto typeattr = cast<mlir::TypeAttr>(attr);
1659 Type type = cast<::mlir::Type>(typeattr.getValue());
1660
1661 if (type != operandType)
1662 return;
1663 }
1664 }
1665 }
1666
1667 // All constraints are satisfied.
1668 res.push_back(op);
1669 return;
1670 };
1671
1672 (*payloadOps.begin())->walk(matchFun);
1673 if (incorrectNumOperandTypes)
1674    return emitDefiniteFailure("if filter_operand_types contains more than "
1675                               "one type, then it must contain as many types "
1676                               "as the number of operands in the target ops");
1677 results.set(cast<OpResult>(getResult()), res);
1678  return DiagnosedSilenceableFailure::success();
1679}
1680
1681//===---------------------------------------------------------------------===//
1682// MultiTileSizesOp
1683//===---------------------------------------------------------------------===//
1684
1685static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1686                                     Type targetType, Type lowSizeType, Type,
1687 Type) {
1688 printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1689}
1690
1691static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1692 Type &targetType, Type &lowSizeType,
1693 Type &highSizeType,
1694 Type &splitPointType) {
1695 FunctionType funcType;
1696 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1697 if (failed(parser.parseType<FunctionType>(funcType)))
1698 return failure();
1699
1700 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1701 parser.emitError(typeLoc) << "expects a trailing functional type with one "
1702 "argument and one result";
1703 }
1704 targetType = funcType.getInput(0);
1705 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1706
1707 return success();
1708}
1709
1710DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1711 transform::TransformRewriter &rewriter, LinalgOp target,
1712    transform::ApplyToEachResultList &results, TransformState &state) {
1713  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1714 if (target.hasDynamicShape()) {
1715 auto diag = emitSilenceableError()
1716 << "cannot compute parametric tile sizes for dynamically "
1717 "shaped payload op";
1718 diag.attachNote(target->getLoc()) << "payload op";
1719 return diag;
1720 }
1721
1722 FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1723 target, getDimension(), getTargetSize(), getDivisor());
1724 if (failed(spec)) {
1725 return emitSilenceableError()
1726 << "failed to compute multi-size tiling sizes";
1727 }
1728
1729 Builder builder(target.getContext());
1730 results.assign(llvm::map_range(
1731 ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1732 spec->lowTileSize * spec->lowTripCount}),
1733 [&builder, this](int64_t value) {
1734 return builder.getIntegerAttr(
1735 cast<ParamType>(getLowSize().getType()).getType(), value);
1736 }));
1737    return DiagnosedSilenceableFailure::success();
1738  }
1739
1740 OpBuilder builder(target.getContext());
1741 builder.setInsertionPoint(target);
1742 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1743 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1744 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1745 builder, target, getDimension(), targetSize, divisor);
1746 if (failed(spec)) {
1747 return emitSilenceableError() << "could not generate tile size computation";
1748 }
1749
1750 AffineExpr s0 = builder.getAffineSymbolExpr(0);
1751 AffineExpr s1 = builder.getAffineSymbolExpr(1);
1752 Operation *splitPoint =
1753 affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1754 {spec->lowTileSize, spec->lowTripCount});
1755 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1756 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1757 assert(lowTileSize && highTileSize && splitPoint &&
1758 "tile sizes are not produced by operations");
1759 results.reserve(results.size() + 3);
1760 results.push_back(lowTileSize);
1761 results.push_back(highTileSize);
1762 results.push_back(splitPoint);
1763  return DiagnosedSilenceableFailure::success();
1764}
1765
1766void transform::MultiTileSizesOp::getEffects(
1767    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1768  onlyReadsHandle(getTargetMutable(), effects);
1769 producesHandle(getOperation()->getOpResults(), effects);
1770 if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1771 onlyReadsPayload(effects);
1772 else
1773 modifiesPayload(effects);
1774}
1775
1776LogicalResult transform::MultiTileSizesOp::verify() {
1777 if (getLowSize().getType() != getHighSize().getType() ||
1778 getLowSize().getType() != getSplitPoint().getType()) {
1779 return emitOpError() << "expects all results type to be the same";
1780 }
1781 return success();
1782}
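
// Illustrative usage sketch for this op: handle names are placeholders and
// the exact assembly syntax is defined by the op's declaration. A typical use
// computes two tile sizes for dimension 0, targeting size 3 with divisor 2:
//
//   %low, %high, %split = transform.structured.multitile_sizes %target
//     { dimension = 0, target_size = 3, divisor = 2 }
//     : (!transform.any_op) -> !transform.any_op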
1783
1784//===---------------------------------------------------------------------===//
1785// PackOp
1786//===---------------------------------------------------------------------===//
1787
1788void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1789 Value target,
1790 ArrayRef<OpFoldResult> mixedPackedSizes) {
1791 SmallVector<int64_t> staticPackedSizes;
1792 SmallVector<Value> dynamicPackedSizes;
1793 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1794 staticPackedSizes);
1795 // Call the default builder which sets up the proper operands segment sizes
1796 // attributes for multiple variadic operands. In the absence of this, horrible
1797 // bugs ensue.
1798 Type linalgOpHType = transform::OperationType::get(
1799 builder.getContext(), GenericOp::getOperationName());
1800 build(builder, result,
1801 /*resultType=*/linalgOpHType,
1802 /*target=*/target,
1803 /*dynamic_sizes=*/dynamicPackedSizes,
1804 /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1805}
1806
1807SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1808 Builder b(getContext());
1809 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1810}
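
// Illustrative usage sketch for this op: handle names are placeholders and
// the exact assembly syntax is defined by the op's declaration. A typical use
// packs the three loops of a matmul-like op with sizes 8, 16 and 32:
//
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//     : (!transform.any_op) -> !transform.op<"linalg.generic">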
1811
1812DiagnosedSilenceableFailure
1813transform::PackOp::apply(transform::TransformRewriter &rewriter,
1814 transform::TransformResults &transformResults,
1815                         transform::TransformState &state) {
1816  auto targetOps = state.getPayloadOps(getTarget());
1817 // If nothing to pack, propagate success.
1818 if (std::empty(targetOps)) {
1819 transformResults.set(cast<OpResult>(getPackedOp()),
1820                         ArrayRef<Operation *>({}));
1821    return DiagnosedSilenceableFailure::success();
1822  }
1823 // Fail on multi-op handles.
1824 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1825 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1826 return emitSilenceableError()
1827 << "requires target to map to exactly 1 LinalgOp (got "
1828 << llvm::range_size(targetOps) << ")";
1829 }
1830 // Fail on mismatched number of pack sizes.
1831 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1832 return emitSilenceableError()
1833 << "requires number of packed sizes match the number of loops ("
1834 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1835 << ")";
1836 }
1837
1838 // Unpack handles to constants or actual SSA index values.
1839 SmallVector<OpFoldResult> packedSizes;
1841 state, *this, packedSizes, getMixedPackedSizes());
1842
1843 rewriter.setInsertionPoint(linalgOp);
1844 FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1845 if (failed(maybeResult))
1846 return emitDefiniteFailure("data tiling failed");
1847
1848 transformResults.set(cast<OpResult>(getPackedOp()),
1849 {maybeResult->packedLinalgOp.getOperation()});
1850  return DiagnosedSilenceableFailure::success();
1851}
1852
1853void transform::PackOp::getEffects(
1854    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1855  transform::consumesHandle(getTargetMutable(), effects);
1856 transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1857 transform::producesHandle(getOperation()->getOpResults(), effects);
1858  transform::modifiesPayload(effects);
1859}
1860
1861//===---------------------------------------------------------------------===//
1862// PackGreedilyOp.
1863//===---------------------------------------------------------------------===//
1864
1865LogicalResult transform::PackGreedilyOp::verify() {
1866 if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1867 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1868 << " is not a valid permutation";
1869 }
1870 // TODO: relax to allow empty once we have another strategy than just matmul.
1871 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1872 for (auto [s, nmo] :
1873 llvm::zip_equal(getMixedMatmulPackedSizes(),
1874 getMatmulPaddedSizesNextMultipleOf())) {
1875 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1876 if (nmo != 0 &&
1877 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1878 return emitOpError() << "at most one of the packed_size and the "
1879 "padded_sizes_next_multiple_of can be nonzero "
1880 "for the matmul strategy";
1881 }
1882 }
1883 }
1884 return success();
1885}
1886
1887DiagnosedSilenceableFailure
1888PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1889 transform::TransformResults &transformResults,
1890                      transform::TransformState &state) {
1891  SmallVector<Operation *> results;
1892  for (Operation *op : state.getPayloadOps(getTarget())) {
1893 auto linalgOp = dyn_cast<LinalgOp>(op);
1894 if (!linalgOp)
1895 continue;
1896 // linalgOp will be replaced and the insertion point may be invalidated if
1897 // we set it before -> set it after.
1898 rewriter.setInsertionPointAfter(linalgOp);
1899 // Failing to pack greedily is perfectly fine.
1900 // In the future we will want to order packings according to some metric.
1901 FailureOr<PackResult> packResult = packMatmulGreedily(
1902 /*rewriter=*/rewriter,
1903 /*linalgOp=*/linalgOp,
1904 /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1905 /*mnkPaddedSizesNextMultipleOf=*/
1906 getMatmulPaddedSizesNextMultipleOf(),
1907 /*mnkOrder=*/getMatmulInnerDimsOrder());
1908 if (succeeded(packResult)) {
1909 results.push_back(packResult->packedLinalgOp);
1910 continue;
1911 }
1912 results.push_back(linalgOp);
1913 }
1914 transformResults.set(cast<OpResult>(getPackedOp()), results);
1915  return DiagnosedSilenceableFailure::success();
1916}
1917
1918SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1919 Builder b(getContext());
1920 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1921 b);
1922}
1923
1924void transform::PackGreedilyOp::getEffects(
1925    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1926  transform::consumesHandle(getTargetMutable(), effects);
1927 transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1928 transform::producesHandle(getOperation()->getOpResults(), effects);
1929  transform::modifiesPayload(effects);
1930}
1931
1932//===---------------------------------------------------------------------===//
1933// PackTransposeOp
1934//===---------------------------------------------------------------------===//
1935
1936LogicalResult transform::PackTransposeOp::verify() {
1937 if (!isPermutationVector(getInnerPerm())) {
1938 return emitOpError() << getInnerPermAttrName()
1939 << " is not a valid permutation";
1940 }
1941 if (!isPermutationVector(getOuterPerm())) {
1942 return emitOpError() << getOuterPermAttrName()
1943 << " is not a valid permutation";
1944 }
1945 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1946 return emitOpError() << " at least one of " << getInnerPermAttrName()
1947 << " or " << getOuterPermAttrName()
1948 << " must be specified";
1949 }
1950 return success();
1951}
1952
1953namespace {
1954enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1955} // namespace
1956
1957/// Return true if `permutation` is a valid permutation of the
1958/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1959/// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` `op`.
1960/// This is the case when the `permutation` rank matches the rank expected by
1961/// `op` and `permutation` is itself a permutation vector.
1962/// Return true if either `op` or `permutation` are empty to allow a simpler
1963/// polymorphic implementation.
1964template <typename RelayoutOpTy>
1965static bool isValidPackingPermutation(
1966 RelayoutOpTy op, ArrayRef<int64_t> permutation,
1967 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1968 static_assert(
1969 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1970 "applies to only pack or unpack operations");
1971 if (!op || permutation.empty())
1972 return true;
1973 size_t innerRank = op.getInnerDimsPos().size();
1974 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1975 return permutation.size() == innerRank && isPermutationVector(permutation);
1976 // op.getOuterDimsPerm() may be empty, in which case it is identity.
1977 // Don't rely on it.
1978 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1979 return permutation.size() == op.getSourceRank() &&
1980 isPermutationVector(permutation);
1981 }
1982 return permutation.size() == op.getDestRank() &&
1983 isPermutationVector(permutation);
1984}
1985
1986DiagnosedSilenceableFailure
1987transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1988 transform::TransformResults &transformResults,
1989                                  transform::TransformState &state) {
1990  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1991 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1992 // Step 1. If nothing to pack, propagate success.
1993 if (std::empty(packOrUnpackOps)) {
1994 transformResults.set(cast<OpResult>(getPackedOp()), {});
1995 transformResults.set(cast<OpResult>(getPackOp()), {});
1996 transformResults.set(cast<OpResult>(getUnPackOp()), {});
1997    return DiagnosedSilenceableFailure::success();
1998  }
1999
2000 // Step 2. Bunch of runtime sanity check and error messages.
2001 // Step 2.1. Fail on multi-op handles.
2002 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2003 !llvm::hasSingleElement(linalgOps)) {
2004 return emitSilenceableError()
2005 << "requires target to map to exactly 1 "
2006 "packing op and 1 packed op ("
2007 << "got " << llvm::range_size(packOrUnpackOps) << " and "
2008 << llvm::range_size(linalgOps) << ")";
2009 }
2010
2011 // Step 2.2. Fail on wrong type.
2012 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2013 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2014 if ((!packOp && !unPackOp)) {
2015 return emitSilenceableError() << "requires target to map to a "
2016 "linalg.pack or linalg.unpack";
2017 }
2018 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2019 if (!linalgOpTarget)
2020 return emitSilenceableError() << "requires a LinalgOp target";
2021
2022 // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
2023 LinalgOp linalgOp;
2024 if (packOp && packOp.getResult().hasOneUse())
2025 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2026 else if (unPackOp)
2027 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2028 if (linalgOp != linalgOpTarget) {
2029 auto errorMsg =
2030 packOp ? StringLiteral{"not a single use by the LinalgOp target"}
2031 : StringLiteral{"not produced by the LinalgOp target"};
2032 return emitSilenceableError() << errorMsg;
2033 }
2034
2035 // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
2036 // PackOp.
2037 if (unPackOp) {
2038 assert(!packOp && "packOp must be null on entry when unPackOp is not null");
2039 OpOperand *packUse = linalgOp.getDpsInitOperand(
2040 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2041 packOp = packUse->get().getDefiningOp<linalg::PackOp>();
2042 if (!packOp || !packOp.getResult().hasOneUse())
2043 return emitSilenceableError() << "could not find matching pack op";
2044 }
2045
2046 // Step 2.5. Fail if any permutation does not validate.
2047 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2048 ArrayRef<int64_t> perm =
2049 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2050 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2051 ? StringLiteral{"invalid outer_perm"}
2052 : StringLiteral{"invalid inner_perm"};
2053 if (!isValidPackingPermutation(packOp, perm, permType) ||
2054 !isValidPackingPermutation(unPackOp, perm, permType)) {
2055 Operation *packOrUnpackOp =
2056 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2057 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
2058 }
2059 }
2060
2061 // From here on, packOp and linalgOp are always present, unPackOp may or may
2062 // not be present.
2063 assert(packOp && linalgOp && "unexpected null op");
2064
2065 // Step 3. Actually transpose the ops.
2066 FailureOr<PackTransposeResult> res = packTranspose(
2067 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2068 // Preconditions have been checked, it is an error to fail here.
2069 assert(succeeded(res) && "unexpected packTranspose failure");
2070
2071 // Step 4. Return results.
2072 transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2073 transformResults.set(cast<OpResult>(getPackedOp()),
2074 {res->transposedLinalgOp});
2075 if (unPackOp) {
2076 transformResults.set(cast<OpResult>(getUnPackOp()),
2077 {res->transposedUnPackOp});
2078 } else {
2079 transformResults.set(cast<OpResult>(getUnPackOp()), {});
2080 }
2081
2082  return DiagnosedSilenceableFailure::success();
2083}
2084
2085//===---------------------------------------------------------------------===//
2086// PadOp
2087//===---------------------------------------------------------------------===//
2088
2089void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2090 ArrayRef<int64_t> paddingDimensions,
2091 ArrayRef<int64_t> padToMultipleOf,
2092 ArrayRef<int64_t> nofoldFlags,
2093 ArrayRef<Attribute> transposePaddings,
2094 StringRef copyBackOp,
2095 bool usePrescribedTensorShapes) {
2096 auto resultType = transform::AnyOpType::get(b.getContext());
2097 return build(/*odsBuilder=*/b,
2098 /*result=*/result,
2099 /*types=*/TypeRange{resultType, resultType},
2100 /*target=*/target,
2101 /*padding_values=*/ArrayAttr(), // let inference handle this
2102 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2103 /*pad_to_multiple_of=*/ValueRange{},
2104 /*padToMultipleOf=*/
2105 (padToMultipleOf.empty()
2106                    ? DenseI64ArrayAttr()
2107                    : b.getDenseI64ArrayAttr(padToMultipleOf)),
2108 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2109 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2110 /*copy_back_op=*/b.getStringAttr(copyBackOp),
2111 /*use_prescribed_tensor_shapes=*/
2112 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
2113}
2114
2115void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2116 ArrayRef<int64_t> paddingDimensions,
2117 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
2118 ArrayRef<int64_t> nofoldFlags,
2119 ArrayRef<Attribute> transposePaddings,
2120 StringRef copyBackOp,
2121 bool usePrescribedTensorShapes) {
2122 auto resultType = transform::AnyOpType::get(b.getContext());
2123 SmallVector<int64_t> staticPadToMultipleOf;
2124 SmallVector<Value> dynamicPadToMultipleOf;
2125 dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
2126 staticPadToMultipleOf);
2127 return build(/*odsBuilder=*/b,
2128 /*result=*/result,
2129 /*types=*/TypeRange{resultType, resultType},
2130 /*target=*/target,
2131 /*padding_values=*/ArrayAttr(), // let inference handle this
2132 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2133 /*pad_to_multiple_of=*/dynamicPadToMultipleOf,
2134 /*padToMultipleOf=*/staticPadToMultipleOf,
2135 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2136 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2137 /*copy_back_op=*/copyBackOp,
2138 /*use_prescribed_tensor_shapes=*/usePrescribedTensorShapes);
2139}
2140
2141void PadOp::getEffects(
2142    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2143  consumesHandle(getTargetMutable(), effects);
2144 onlyReadsHandle(getPadToMultipleOfMutable(), effects);
2145 producesHandle(getOperation()->getOpResults(), effects);
2146 modifiesPayload(effects);
2147}
2148
2149SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2150 Builder b(getContext());
2151 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
2152}
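
// Illustrative usage sketch for this op: handle names are placeholders and
// the exact assembly syntax is defined by the op's declaration. A typical use
// pads a matmul-like op with zeros and copies the result back via linalg.copy:
//
//   %padded, %pad, %copy = transform.structured.pad %target {
//     padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//     padding_dimensions = [0, 1, 2],
//     nofold_flags = [1, 1, 1],
//     copy_back_op = "linalg.copy"
//   } : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)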
2153
2154DiagnosedSilenceableFailure
2155transform::PadOp::apply(transform::TransformRewriter &rewriter,
2156 transform::TransformResults &results,
2157 transform::TransformState &state) {
2158 auto transformOp = cast<TransformOpInterface>(getOperation());
2159 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
2160
2161 for (Operation *target : state.getPayloadOps(getTarget())) {
2162 auto linalgTarget = dyn_cast<LinalgOp>(target);
2163 if (!linalgTarget) {
2164 auto diag = emitSilenceableError() << "expected LinalgOp target";
2165 diag.attachNote(target->getLoc()) << "target op";
2166 return diag;
2167 }
2168
2169 // Convert the integer packing flags to booleans.
2170 SmallVector<bool> nofoldFlags;
2171 for (int64_t packPadding :
2172 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
2173 nofoldFlags.push_back(static_cast<bool>(packPadding));
2174
2175 // Convert the padding values to attributes.
2176 SmallVector<Attribute> paddingValues;
2177 for (auto const &[untypedAttr, elementOrTensorType] :
2178 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2179
2180 if (isa<ub::PoisonAttr>(untypedAttr)) {
2181 paddingValues.push_back(untypedAttr);
2182 continue;
2183 }
2184 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2185 if (!attr) {
2186 emitOpError("expects padding values to be typed attributes or poison");
2187        return DiagnosedSilenceableFailure::definiteFailure();
2188      }
2189 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2190 // Try to parse string attributes to obtain an attribute of element type.
2191 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2192 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2193 stringAttr, getContext(), elementType,
2194 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2195 if (!parsedAttr || parsedAttr.getType() != elementType) {
2196 auto diag = this->emitOpError("expects a padding that parses to ")
2197 << elementType << ", got " << untypedAttr;
2198 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2199          return DiagnosedSilenceableFailure::definiteFailure();
2200        }
2201 paddingValues.push_back(parsedAttr);
2202 continue;
2203 }
2204 // Otherwise, add the attribute directly.
2205 if (attr.getType() != elementType) {
2206 auto diag = this->emitOpError("expects a padding value of type ")
2207 << elementType << ", got " << attr;
2208 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2209        return DiagnosedSilenceableFailure::definiteFailure();
2210      }
2211 paddingValues.push_back(attr);
2212 }
2213
2214 // Extract the transpose vectors.
2215 SmallVector<SmallVector<int64_t>> transposePaddings;
2216 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2217 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2218 cast<ArrayAttr>(transposeVector)));
2219
2220 LinalgOp paddedOp;
2221 LinalgPaddingOptions options;
2222 options.paddingDimensions =
2223 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2224
2225 SmallVector<int64_t> padToMultipleOf;
2226 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
2227 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2228 if (!status.succeeded())
2229 return status;
2230 if (padToMultipleOf.empty())
2231 padToMultipleOf =
2232 SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2233
2234 options.padToMultipleOf = padToMultipleOf;
2235 options.paddingValues = paddingValues;
2236 options.nofoldFlags = nofoldFlags;
2237 if (getCopyBackOp() ==
2238 bufferization::MaterializeInDestinationOp::getOperationName()) {
2239 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2240 BufferizationMaterializeInDestination;
2241 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2242 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2243 } else if (getCopyBackOp() == kCopyOpNone) {
2244 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2245 } else {
2246 llvm_unreachable("unsupported copy_back op");
2247 }
2248 // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2249 bool irChanged = false;
2250 if (getUsePrescribedTensorShapes() &&
2251 linalgTarget.hasPureTensorSemantics()) {
2252 OpBuilder::InsertionGuard g(rewriter);
2253 rewriter.setInsertionPoint(linalgTarget);
2254 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2255 for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2256 if (ShapedType::isStatic(dim))
2257 continue;
2258 options.setSizeToPadTo(operand.getOperandNumber(), i,
2259 tensor::getMixedSize(rewriter,
2260 operand.get().getLoc(),
2261 operand.get(), i));
2262 irChanged = true;
2263 }
2264 }
2265 }
2266
2267 SmallVector<Value> replacements;
2268 SmallVector<tensor::PadOp> newPadOps;
2269 if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2270 replacements, newPadOps))) {
2271 if (irChanged) {
2272 auto diag = emitDefiniteFailure() << "failed to pad op";
2273 diag.attachNote(target->getLoc()) << "target op";
2274 return diag;
2275 }
2276 auto diag = emitSilenceableError() << "failed to pad op";
2277 diag.attachNote(target->getLoc()) << "target op";
2278 return diag;
2279 }
2280
2281 // We need to perform our own replacement here because this API is still
2282 // used in patterns that "pad and hoist", for which the replacement values
2283 // need to be different.
2284 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2285 // that we have more composable abstractions.
2286 rewriter.replaceOp(linalgTarget, replacements);
2287 paddedOps.push_back(paddedOp);
2288 padOps.append(newPadOps.begin(), newPadOps.end());
2289 if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2290 for (Value v : replacements) {
2291 Operation *copyBackOp = v.getDefiningOp();
2292 if (!llvm::is_contained(copyBackOps, copyBackOp))
2293 copyBackOps.push_back(copyBackOp);
2294 }
2295 }
2296 }
2297
2298 results.set(cast<OpResult>(getPadded()), paddedOps);
2299 results.set(cast<OpResult>(getPad()), padOps);
2300 results.set(cast<OpResult>(getCopy()), copyBackOps);
2301  return DiagnosedSilenceableFailure::success();
2302}
2303
2304LogicalResult transform::PadOp::verify() {
2305 SmallVector<int64_t> nofoldFlags =
2306 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2307 if (any_of(nofoldFlags, [](int64_t packPadding) {
2308 return packPadding != 0 && packPadding != 1;
2309 })) {
2310 return emitOpError()
2311 << "expects nofold_flags to contain booleans (0/1), found "
2312 << getNofoldFlags();
2313 }
2314
2315 SmallVector<int64_t> paddingDimensions =
2316 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2317 if (any_of(paddingDimensions,
2318 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2319    return emitOpError() << "expects padding_dimensions to contain non-negative "
2320                            "integers, found "
2321 << getPaddingDimensions();
2322 }
2323 if (!getMixedPadToMultipleOf().empty()) {
2324 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2325 return emitOpError() << "expects as many multiples as padding_dimensions";
2326 }
2327 }
2328 ArrayAttr transposes = getTransposePaddings();
2329 for (Attribute attr : transposes) {
2330 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2331 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2332 if (!std::is_permutation(sequence.begin(), sequence.end(),
2333 transpose.begin(), transpose.end())) {
2334 return emitOpError()
2335 << "expects transpose_paddings to be a permutation, found "
2336 << attr;
2337 }
2338 }
2339 if (getCopyBackOp() !=
2340 bufferization::MaterializeInDestinationOp::getOperationName() &&
2341 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2342 getCopyBackOp() != kCopyOpNone)
2343 return emitOpError() << "invalid copy_back_op";
2344 return success();
2345}
2346
2347//===---------------------------------------------------------------------===//
2348// PadTilingInterfaceOp
2349//===---------------------------------------------------------------------===//
2350
2351void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2352 OperationState &result,
2353 Value target,
2354 ArrayRef<int64_t> paddingSizes,
2355 bool padToMultipleOf) {
2356 auto resultType = transform::AnyOpType::get(b.getContext());
2357 return build(/*odsBuilder=*/b,
2358 /*result=*/result,
2359 /*types=*/TypeRange{resultType, resultType},
2360 /*target=*/target,
2361 /*padding_values=*/ArrayAttr(), // let inference handle this
2362 /*padding_sizes=*/ValueRange{},
2363 /*paddingSizes=*/
2364 (paddingSizes.empty() ? DenseI64ArrayAttr()
2365 : b.getDenseI64ArrayAttr(paddingSizes)),
2366 /*pad_to_multiple_of=*/
2367 padToMultipleOf ? b.getUnitAttr() : nullptr);
2368}
2369
2370void transform::PadTilingInterfaceOp::build(
2371 OpBuilder &b, OperationState &result, Value target,
2372 ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2373 auto resultType = transform::AnyOpType::get(b.getContext());
2374 SmallVector<int64_t> staticPaddingSizes;
2375 SmallVector<Value> dynamicPaddingSizes;
2376 dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2377 staticPaddingSizes);
2378 return build(/*odsBuilder=*/b,
2379 /*result=*/result,
2380 /*types=*/TypeRange{resultType, resultType},
2381 /*target=*/target,
2382 /*padding_values=*/ArrayAttr(), // let inference handle this
2383 /*padding_sizes=*/dynamicPaddingSizes,
2384 /*paddingSizes=*/staticPaddingSizes,
2385 /*usePrescribedTensorShapes=*/padToMultipleOf);
2386}
2387
2388void transform::PadTilingInterfaceOp::getEffects(
2389 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2390 consumesHandle(getTargetMutable(), effects);
2391 onlyReadsHandle(getPaddingSizesMutable(), effects);
2392 producesHandle(getOperation()->getOpResults(), effects);
2393 modifiesPayload(effects);
2394}
2395
2396SmallVector<OpFoldResult>
2397transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2398 Builder b(getContext());
2399 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2400}
2401
2402DiagnosedSilenceableFailure
2403transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2404 transform::TransformResults &results,
2405 transform::TransformState &state) {
2406 SmallVector<Operation *> paddedOps, padOps;
2407
2408 for (Operation *target : state.getPayloadOps(getTarget())) {
2409 auto targetOp = dyn_cast<TilingInterface>(target);
2410 if (!targetOp) {
2411 auto diag = emitSilenceableError() << "expected TilingInterface target";
2412 diag.attachNote(target->getLoc()) << "target op";
2413 return diag;
2414 }
2415
2416 // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2417 // loopsToOperand map / C++ APIs to compute the effect of padding on
2418 // operands.
2419 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2420      auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2421                                            "are supported at the moment";
2422 diag.attachNote(target->getLoc()) << "target op";
2423 return diag;
2424 }
2425
2426 // Convert the padding values to attributes.
2427 SmallVector<Attribute> paddingValues;
2428 for (auto const &[untypedAttr, elementOrTensorType] :
2429 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2430 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2431 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2432
2433 if (isa<ub::PoisonAttr>(untypedAttr)) {
2434 paddingValues.push_back(untypedAttr);
2435 continue;
2436 }
2437 if (!attr) {
2438 emitOpError("expects padding values to be typed attributes or poison");
2439        return DiagnosedSilenceableFailure::definiteFailure();
2440      }
2441 // Try to parse string attributes to obtain an attribute of element type.
2442 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2443 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2444 stringAttr, getContext(), elementType,
2445 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2446 if (!parsedAttr || parsedAttr.getType() != elementType) {
2447 auto diag = this->emitOpError("expects a padding that parses to ")
2448 << elementType << ", got " << attr;
2449 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2450          return DiagnosedSilenceableFailure::definiteFailure();
2451        }
2452 paddingValues.push_back(parsedAttr);
2453 continue;
2454 }
2455 // Otherwise, add the attribute directly.
2456 if (attr.getType() != elementType) {
2457 auto diag = this->emitOpError("expects a padding value of type ")
2458 << elementType << ", got " << attr;
2459 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2460        return DiagnosedSilenceableFailure::definiteFailure();
2461      }
2462 paddingValues.push_back(attr);
2463 }
2464
2465 // Set options.
2466 PadTilingInterfaceOptions options;
2467 options.setPaddingValues(paddingValues)
2468 .setPaddingSizes(getMixedPaddingSizes())
2469 .setPadToMultipleOf(getPadToMultipleOf());
2470
2471 OpBuilder::InsertionGuard g(rewriter);
2472 rewriter.setInsertionPointAfter(targetOp);
2473 auto maybePadOps = rewriteAsPaddedOp(
2474 rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
2475 if (failed(maybePadOps)) {
2476 auto diag = emitSilenceableError() << "failed to pad op";
2477 diag.attachNote(target->getLoc()) << "target op";
2478 return diag;
2479 }
2480 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2481
2482 // Set transform results.
2483 paddedOps.push_back(paddedOp);
2484 padOps.append(paddedOperands.begin(), paddedOperands.end());
2485 rewriter.replaceOp(targetOp.getOperation(), slicedResults);
2486 }
2487
2488 results.set(cast<OpResult>(getPadded()), paddedOps);
2489 results.set(cast<OpResult>(getPad()), padOps);
2490  return DiagnosedSilenceableFailure::success();
2491}
2492
2493LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2494
2495//===---------------------------------------------------------------------===//
2496// HoistPadOp
2497//===---------------------------------------------------------------------===//
2498
2499DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2500 transform::TransformRewriter &rewriter,
2501 transform::TransformResults &transformResults,
2502 transform::TransformState &state) {
2503 auto targetOps = state.getPayloadOps(getTarget());
2504 auto loopOps = state.getPayloadOps(getLoop());
2505 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2506 return emitDefiniteFailure()
2507 << "requires exactly one target and one loop handle (got "
2508 << llvm::range_size(targetOps) << " and "
2509 << llvm::range_size(loopOps) << ")";
2510 }
2511
2512 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2513 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2514 if (!padOp || !loopOp)
2515 return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2516
2517 FailureOr<linalg::detail::PackingResult> result =
2518 linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2519 getTranspose());
2520 if (failed(result))
2521 return emitDefiniteFailure() << "could not build packing loop nest";
2522
2523 if (result->clonedLoopIvs.empty()) {
2524 transformResults.set(cast<OpResult>(getPackingLoop()),
2525 {result->hoistedPadOp.getOperation()});
2526    return DiagnosedSilenceableFailure::success();
2527  }
2528 auto outerPackedLoop =
2529 scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2530 transformResults.set(cast<OpResult>(getPackingLoop()),
2531 {outerPackedLoop.getOperation()});
2532  return DiagnosedSilenceableFailure::success();
2533}
2534
2535LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2536 ArrayRef<int64_t> transpose = getTranspose();
2537 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2538 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2539 transpose.end())) {
2540 return emitOpError() << "expects transpose to be a permutation, found "
2541 << getTranspose();
2542 }
2543 return success();
2544}
2545
2546void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2547 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2548 transform::onlyReadsHandle(getTargetMutable(), effects);
2549 transform::onlyReadsHandle(getLoopMutable(), effects);
2550 transform::producesHandle(getOperation()->getOpResults(), effects);
2551  transform::modifiesPayload(effects);
2552}
2553
2554DiagnosedSilenceableFailure
2555transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2556 tensor::PadOp target,
2557 transform::ApplyToEachResultList &results,
2558 transform::TransformState &state) {
2559 tensor::PadOp hoistedPadOp;
2560 SmallVector<TransposeOp> transposeOps;
2561 FailureOr<Value> result =
2562 hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2563 hoistedPadOp, transposeOps);
2564 if (succeeded(result)) {
2565 // We need to perform our own replacement here because this API is still
2566 // used in patterns that "pad and hoist", for which the replacement values
2567 // need to be different.
2568 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2569 // that we have more composable abstractions.
2570 rewriter.replaceOp(target, *result);
2571 results.push_back(hoistedPadOp);
2572    return DiagnosedSilenceableFailure::success();
2573  }
2574 return emitDefaultSilenceableFailure(target);
2575}
2576
2577LogicalResult transform::HoistPadOp::verify() {
2578 ArrayRef<int64_t> transpose = getTranspose();
2579 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2580 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2581 transpose.end())) {
2582 return emitOpError() << "expects transpose to be a permutation, found "
2583 << getTranspose();
2584 }
2585 return success();
2586}
2587
2588//===----------------------------------------------------------------------===//
2589// PromoteOp
2590//===----------------------------------------------------------------------===//
2591
2592DiagnosedSilenceableFailure
2593transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2594 LinalgOp target,
2595 transform::ApplyToEachResultList &results,
2596 transform::TransformState &state) {
2597 LinalgPromotionOptions promotionOptions;
2598 if (!getOperandsToPromote().empty())
2599 promotionOptions = promotionOptions.setOperandsToPromote(
2600 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2601 if (getUseFullTilesByDefault())
2602 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2603 getUseFullTilesByDefault());
2604 if (getUseOriginalSubviewSize())
2605 promotionOptions =
2606 promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
2607 if (getUseAlloca())
2608 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2609 if (!getUseFullTileBuffers().empty())
2610 promotionOptions = promotionOptions.setUseFullTileBuffers(
2611 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2612 if (getAlignment().has_value())
2613 promotionOptions = promotionOptions.setAlignment(*getAlignment());
2614 if (getMemorySpace().has_value())
2615 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2616
2617 if (getMapping().has_value()) {
2618    // The mapping should only contain one element.
2619 auto mapping = *getMapping();
2620 if (mapping.size() > 1)
2621 return emitDefaultDefiniteFailure(target);
2622
2623 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2624
2625 if (addressSpace.getAddressSpace() ==
2626 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2627 promotionOptions =
2628 promotionOptions
2629              .setAllocationDeallocationFns(allocateWorkgroupMemory,
2630                                            deallocateWorkgroupMemory)
2631              .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
2632              .setUseFullTileBuffers({false, false});
2633 } else if (addressSpace.getAddressSpace() ==
2634 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2635 promotionOptions =
2636 promotionOptions
2637              .setAllocationDeallocationFns(allocateGPUPrivateMemory,
2638                                            deallocateGPUPrivateMemory)
2639              .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
2640              .setUseFullTileBuffers({false, false});
2641 } else {
2642 return emitDefaultDefiniteFailure(target);
2643 }
2644 }
2645
2646 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2647 return emitDefaultDefiniteFailure(target);
2648
2649 rewriter.setInsertionPoint(target);
2650 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2651 if (failed(res))
2652 return emitDefaultDefiniteFailure(target);
2653 results.push_back(target);
2654  return DiagnosedSilenceableFailure::success();
2655}
2656
2657//===----------------------------------------------------------------------===//
2658// ReplaceOp
2659//===----------------------------------------------------------------------===//
2660
2661DiagnosedSilenceableFailure
2662transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2663 TransformResults &transformResults,
2664 TransformState &state) {
2665 auto payload = state.getPayloadOps(getTarget());
2666
2667 // Check for invalid targets.
2668 for (Operation *target : payload) {
2669 if (target->getNumOperands() > 0)
2670 return emitDefiniteFailure() << "expected target without operands";
2671 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2672 target->getNumRegions() > 0)
2673 return emitDefiniteFailure()
2674 << "expected target that is isolated from above";
2675 }
2676
2677 // Clone and replace.
2678 Operation *pattern = &getBodyRegion().front().front();
2679 SmallVector<Operation *> replacements;
2680 for (Operation *target : payload) {
2681 if (getOperation()->isAncestor(target))
2682 continue;
2683 rewriter.setInsertionPoint(target);
2684 Operation *replacement = rewriter.clone(*pattern);
2685 rewriter.replaceOp(target, replacement->getResults());
2686 replacements.push_back(replacement);
2687 }
2688 transformResults.set(cast<OpResult>(getReplacement()), replacements);
2689  return DiagnosedSilenceableFailure::success();
2690}
2691
2692void transform::ReplaceOp::getEffects(
2693 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2694 consumesHandle(getTargetMutable(), effects);
2695 producesHandle(getOperation()->getOpResults(), effects);
2696 modifiesPayload(effects);
2697}
2698
2699LogicalResult transform::ReplaceOp::verify() {
2700 if (!getBodyRegion().hasOneBlock())
2701 return emitOpError() << "expected one block";
2702 if (std::distance(getBodyRegion().front().begin(),
2703 getBodyRegion().front().end()) != 1)
2704 return emitOpError() << "expected one operation in block";
2705 Operation *replacement = &getBodyRegion().front().front();
2706 if (replacement->getNumOperands() > 0)
2707 return replacement->emitOpError()
2708 << "expected replacement without operands";
2709 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2710 replacement->getNumRegions() > 0)
2711 return replacement->emitOpError()
2712 << "expect op that is isolated from above";
2713 return success();
2714}
2715
2716//===----------------------------------------------------------------------===//
2717// ScalarizeOp
2718//===----------------------------------------------------------------------===//
2719
2720DiagnosedSilenceableFailure
2721transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2722 LinalgOp target,
2723 transform::ApplyToEachResultList &results,
2724 transform::TransformState &state) {
2725 scf::SCFTilingOptions tilingOptions;
2726 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2727 SmallVector<OpFoldResult> tileSizes;
2728 Location loc = target.getLoc();
2729 SmallVector<OpFoldResult> allShapeSizes =
2730 target.createFlatListOfOperandDims(b, loc);
2731 AffineMap map = target.getShapesToLoopsMap();
2732 if (!map)
2733 return tileSizes;
2734 SmallVector<OpFoldResult> shapeSizes =
2735        affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
2736                                                         allShapeSizes);
2737 // If the shape size is dynamic, tile by 1.
2738 // Otherwise, do not tile (i.e. tile size 0).
2739 for (OpFoldResult shapeSize : shapeSizes) {
2740 tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2741 : b.getIndexAttr(1));
2742 }
2743 return tileSizes;
2744 });
2745 rewriter.setInsertionPoint(target);
2746 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2747 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2748 if (failed(maybeTilingResult))
2749 return emitDefaultDefiniteFailure(target);
2750
2751 if (target->getNumResults())
2752 rewriter.replaceOp(target, maybeTilingResult->replacements);
2753 else
2754 rewriter.eraseOp(target);
2755
2756 results.reserve(maybeTilingResult->tiledOps.size());
2757 for (Operation *tiled : maybeTilingResult->tiledOps)
2758 results.push_back(tiled);
2759  return DiagnosedSilenceableFailure::success();
2760}
2761
2762//===----------------------------------------------------------------------===//
2763// ConvertToLoopsOp
2764//===----------------------------------------------------------------------===//
2765
2766DiagnosedSilenceableFailure
2767transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2768 transform::TransformResults &results,
2769 transform::TransformState &state) {
2770 SmallVector<Operation *> loops;
2771 for (Operation *target : state.getPayloadOps(getTarget())) {
2772 auto tilingOp = dyn_cast<TilingInterface>(*target);
2773 if (!tilingOp) {
2774 DiagnosedSilenceableFailure diag =
2775 emitSilenceableError()
2776 << "expected the payload to implement TilingInterface";
2777 diag.attachNote(target->getLoc()) << "payload op";
2778 return diag;
2779 }
2780 rewriter.setInsertionPoint(target);
2781 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2782 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2783 if (failed(generatedLoops))
2784 return emitDefaultDefiniteFailure(target);
2785 for (scf::ForOp &loop : *generatedLoops) {
2786 loops.push_back(loop.getOperation());
2787 }
2788 rewriter.eraseOp(target);
2789 }
2790 results.set(cast<OpResult>(getResult()), loops);
2791  return DiagnosedSilenceableFailure::success();
2792}
2793
2794//===----------------------------------------------------------------------===//
2795// RewriteInDestinationPassingStyleOp
2796//===----------------------------------------------------------------------===//
2797
2798DiagnosedSilenceableFailure
2799transform::RewriteInDestinationPassingStyleOp::applyToOne(
2800 transform::TransformRewriter &rewriter, Operation *target,
2801 transform::ApplyToEachResultList &results,
2802 transform::TransformState &state) {
2803 rewriter.setInsertionPoint(target);
2804 FailureOr<Operation *> maybeResult =
2805      llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2806          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2807 [&rewriter](auto op) {
2808 return rewriteInDestinationPassingStyle(rewriter, op);
2809 });
2810 if (failed(maybeResult))
2811 return emitDefaultSilenceableFailure(target);
2812 results.push_back(*maybeResult);
2813  return DiagnosedSilenceableFailure::success();
2814}
2815
2816//===----------------------------------------------------------------------===//
2817// SplitOp
2818//===----------------------------------------------------------------------===//
2819
2820DiagnosedSilenceableFailure
2821SplitOp::apply(transform::TransformRewriter &rewriter,
2822 TransformResults &results, TransformState &state) {
2823 // Collect the dynamic split points if provided.
2824 SmallVector<Operation *> payload =
2825 llvm::to_vector(state.getPayloadOps(getTarget()));
2826
2827 bool isMultiwaySplit = getMultiway();
2828
2829 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2830 return mlir::emitSilenceableFailure(getLoc())
2831 << "requires exactly one target when "
2832 "multiway split is enabled (got "
2833 << llvm::range_size(payload) << ")";
2834 }
2835
2836 SmallVector<OpFoldResult> chunkSizes;
2837
2838 if (!isMultiwaySplit)
2839 chunkSizes.reserve(payload.size());
2840
2841 if (getDynamicChunkSizes()) {
2842    auto diag = DiagnosedSilenceableFailure::success();
2843    if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2844 chunkSizes = llvm::to_vector(llvm::map_range(
2845 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2846 if (op->getNumResults() != 1 ||
2847 !op->getResult(0).getType().isIndex()) {
2848 diag = emitSilenceableError()
2849 << "expected dynamic split point handle to point to a "
2850 "single-result index-typed op";
2851 diag.attachNote(op->getLoc()) << "dynamic split point";
2852 }
2853 return OpFoldResult(op->getResult(0));
2854 }));
2855 } else {
2856 chunkSizes = llvm::to_vector(
2857 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2858 [](Attribute attr) { return OpFoldResult(attr); }));
2859 }
2860 if (diag.isSilenceableFailure())
2861 return diag;
2862
2863 // For multiway split, a single payload is expected to have multiple
2864 // split points.
2865 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2866 return emitDefiniteFailure()
2867 << "expected the dynamic split point handle to point to as "
2868 "many operations ("
2869 << chunkSizes.size() << ") as the target handle ("
2870 << payload.size() << ")";
2871 }
2872 } else {
2873 chunkSizes.resize(payload.size(),
2874 rewriter.getIndexAttr(getStaticChunkSizes()));
2875 }
2876
2877 auto checkStructuredOpAndDimensions =
2878 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2879 if (!linalgOp) {
2880 auto diag = emitSilenceableError() << "only applies to structured ops";
2881 diag.attachNote(loc) << "target op";
2882 return diag;
2883 }
2884
2885 if (getDimension() >= linalgOp.getNumLoops()) {
2886 auto diag = emitSilenceableError() << "dimension " << getDimension()
2887 << " does not exist in target op";
2888 diag.attachNote(loc) << "target op";
2889 return diag;
2890 }
2891    return DiagnosedSilenceableFailure::success();
2892  };
2893
2894 auto checkFailureInSplitting =
2895 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2896 if (hasFailed) {
2897 auto diag = emitDefiniteFailure() << "internal failure in splitting";
2898 diag.attachNote(loc) << "target op";
2899 return diag;
2900 }
2901    return DiagnosedSilenceableFailure::success();
2902  };
2903
2904 SmallVector<Operation *> opList;
2905 if (isMultiwaySplit) {
2906
2907 // Split a single target operation at multiple points.
2908 TilingInterface head, tail;
2909 Operation *target = payload.front();
2910
2911 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2912
2913 // Check that the target is a valid LinalgOp with correct dimensions.
2914 DiagnosedSilenceableFailure diag =
2915 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2916 if (diag.isSilenceableFailure())
2917 return diag;
2918
2919 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2920
2921 if (idx > 0)
2922 target = tail.getOperation();
2923
2924 if (!target)
2925 break;
2926
2927 linalgOp = cast<LinalgOp>(target);
2928 Location loc = target->getLoc();
2929
2930 rewriter.setInsertionPoint(linalgOp);
2931 std::tie(head, tail) = linalg::splitOp(
2932 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2933 getDimension(), chunkSize);
2934
2935 // Propagate errors.
2936 DiagnosedSilenceableFailure diag =
2937 checkFailureInSplitting(!head && !tail, loc);
2938 if (diag.isDefiniteFailure())
2939 return diag;
2940
2941 opList.push_back(head.getOperation());
2942 }
2943
2944 // Append any leftover parts to the end of the result list.
2945 if (tail)
2946 opList.push_back(tail.getOperation());
2947
2948 } else {
2949 // Split each target operation.
2950 SmallVector<Operation *> first, second;
2951 Operation *noSecondPart = nullptr;
2952 for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2953 Operation *target = std::get<0>(pair);
2954 Location loc = target->getLoc();
2955 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2956 DiagnosedSilenceableFailure diag =
2957 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2958
2959 if (diag.isSilenceableFailure())
2960 return diag;
2961
2962 rewriter.setInsertionPoint(linalgOp);
2963 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2964 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2965 getDimension(), std::get<1>(pair));
2966
2967 // Propagate errors.
2968 DiagnosedSilenceableFailure diagSplit =
2969 checkFailureInSplitting(!first.back() && !second.back(), loc);
2970 if (diagSplit.isDefiniteFailure())
2971      return diagSplit;
2972
2973 // Do not add null second parts.
2974 if (!second.back()) {
2975 noSecondPart = target;
2976 second.pop_back();
2977 }
2978 }
2979
2980 if (second.size() != first.size() && !second.empty()) {
2981 auto diag = emitSilenceableError()
2982 << "splitting does not produce the second part for a subset "
2983 "of targets";
2984 diag.attachNote()
2985 << "expected splitting to produce the second part of all "
2986 "or none of the targets";
2987 diag.attachNote(noSecondPart->getLoc())
2988 << "first target with no second part";
2989 return diag;
2990 }
2991
2992 opList.append(first);
2993 if (!second.empty())
2994 opList.append(second);
2995 }
2996 results.set(cast<OpResult>(getSplitList()), opList);
2997  return DiagnosedSilenceableFailure::success();
2998}
2999
3000void SplitOp::getEffects(
3001 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3002 consumesHandle(getTargetMutable(), effects);
3003 if (getDynamicChunkSizes())
3004 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
3005 producesHandle(getOperation()->getOpResults(), effects);
3006 modifiesPayload(effects);
3007}
3008
3009ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
3010 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
3011 IntegerAttr staticChunkSizes;
3012 if (parser.parseOperand(target) || parser.parseKeyword("after"))
3013 return failure();
3014
3015 OptionalParseResult dynamicPointParseResult =
3016 parser.parseOptionalOperand(dynamicChunkSizes);
3017 if (!dynamicPointParseResult.has_value()) {
3018 int64_t staticChunkSizesValue;
3019 if (failed(parser.parseInteger(staticChunkSizesValue)))
3020 return failure();
3021
3022 staticChunkSizes =
3023 parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
3024 }
3025
3026 Type targetType;
3027 if (parser.parseOptionalAttrDict(result.attributes) ||
3028 parser.parseColonType(targetType) ||
3029 parser.resolveOperand(target, targetType, result.operands)) {
3030 return failure();
3031 }
3032 if (dynamicPointParseResult.has_value()) {
3033 Type chunkSizesType;
3034 if (failed(*dynamicPointParseResult) || parser.parseComma() ||
3035 parser.parseType(chunkSizesType) ||
3036 parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
3037 result.operands)) {
3038 return failure();
3039 }
3040
3041 staticChunkSizes =
3042 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
3043 }
3044
3045 result.addAttribute(
3046 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
3047 staticChunkSizes);
3048 result.addTypes(targetType);
3049 return success();
3050}
3051
3052void SplitOp::print(OpAsmPrinter &printer) {
3053 printer << " " << getTarget() << " after ";
3054 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
3055 if (staticChunkSize != ShapedType::kDynamic)
3056 printer << staticChunkSize;
3057 else
3058 printer << getDynamicChunkSizes();
3059 printer << " ";
3060 printer.printOptionalAttrDict(getOperation()->getAttrs(),
3061 {getStaticChunkSizesAttrName()});
3062 printer << " : " << getTarget().getType();
3063 if (staticChunkSize == ShapedType::kDynamic)
3064 printer << ", " << getDynamicChunkSizes().getType();
3065}
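
// Illustrative usage sketch, following the custom parser and printer above:
// the handle name is a placeholder and the attribute set is indicative only.
// Splitting a structured op after 16 iterations of dimension 0:
//
//   %splits = transform.structured.split %target after 16 { dimension = 0 }
//     : !transform.any_op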
3066
3067LogicalResult SplitOp::verify() {
3068 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3069 (getDynamicChunkSizes() == nullptr)) {
3070 return emitOpError() << "expects either a dynamic or a static split "
3071 "point to be provided";
3072 }
3073 return success();
3074}
3075
3076//===----------------------------------------------------------------------===//
3077// SplitReductionOp
3078//===----------------------------------------------------------------------===//
3079
3080void transform::SplitReductionOp::build(
3081 OpBuilder &builder, OperationState &result, Value target,
3082 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
3083 bool useScalingAlgorithm, bool useAlloc) {
3084 MLIRContext *ctx = builder.getContext();
3085 result.addOperands(target);
3086 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
3087 builder.getI64IntegerAttr(splitFactor));
3088 result.addAttribute(
3089 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
3090 builder.getI64IntegerAttr(insertSplitDimension));
3091 if (innerParallel) {
3092 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
3093 builder.getUnitAttr());
3094 }
3095 if (useScalingAlgorithm) {
3096 result.addAttribute(
3097 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
3098 builder.getUnitAttr());
3099 }
3100 if (useAlloc) {
3101 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
3102 builder.getUnitAttr());
3103 }
3104 auto resultType = transform::AnyOpType::get(ctx);
3105 result.addTypes({resultType, resultType, resultType, resultType});
3106}
3107
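// The four result handles pushed below map, in order, to: the op initializing
// (or allocating) the expanded accumulator, the fill op, the partial-reduction
// linalg op iterating over the extra split dimension, and the op combining the
// partial results. For instance (illustrative numbers), splitting a sum
// reduction over 1024 elements with splitFactor = 4 roughly reshapes the input
// into a 4x256 view, computes four partial sums, and then combines them.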
3108DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3109 transform::TransformRewriter &rewriter, LinalgOp target,
3110 transform::ApplyToEachResultList &results,
3111 transform::TransformState &state) {
3112 ControlSplitReductionFn splitFn = [&](LinalgOp) {
3113 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3114 unsigned(getInsertSplitDimension()),
3115 bool(getInnerParallel())};
3116 };
3117 rewriter.setInsertionPoint(target);
3118 FailureOr<SplitReductionResult> splitResult =
3119 (getUseScalingAlgorithm())
3120 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
3121 : splitReduction(rewriter, target, splitFn, getUseAlloc());
3122 if (failed(splitResult))
3123 return emitDefaultDefiniteFailure(target);
3124
3125 results.push_back(splitResult->initOrAlloc);
3126 results.push_back(splitResult->fillOp);
3127 results.push_back(splitResult->splitLinalgOp);
3128 results.push_back(splitResult->resultCombiningLinalgOp);
3129  return DiagnosedSilenceableFailure::success();
3130}
3131
3132//===----------------------------------------------------------------------===//
3133// TileReductionUsingForOp
3134//===----------------------------------------------------------------------===//
3135
3136void transform::TileReductionUsingForOp::build(
3137 OpBuilder &builder, OperationState &result, Value target,
3138 ArrayRef<int64_t> staticTileSizes) {
3139 // Call the default builder.
3140  // This is future-proof regarding mixed static-dynamic sizes and sets up the
3141  // proper operand segment size attributes for multiple variadic operands.
3142 // In the absence of this, horrible bugs ensue.
3143 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3144 MLIRContext *ctx = builder.getContext();
3145 auto opTy = transform::AnyOpType::get(ctx);
3146 auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
3147 build(builder, result,
3148 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3149 /*target=*/target,
3150 /*reduction_dims=*/nullptr,
3151 /*tile_sizes=*/staticTileSizesAttr);
3152}
3153
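// The result handles pushed below are, in order: the defining ops of the
// initial values, the partial-reduction ops produced by tiling, the ops
// merging the partial results, and the outermost generated loop.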
3154DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3155 transform::TransformRewriter &rewriter, Operation *target,
3156 transform::ApplyToEachResultList &results,
3157 transform::TransformState &state) {
3158 rewriter.setInsertionPoint(target);
3159
3160 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3161 if (!partialReductionOp) {
3162    return emitSilenceableFailure(
3163        target->getLoc(),
3164 "Operation should implement PartialReductionOpInterface");
3165 }
3166
3167 SmallVector<unsigned> reductionDims =
3168 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3169 if (reductionDims.empty()) {
3170 for (auto [idx, iteratorType] :
3171 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3172 if (iteratorType == utils::IteratorType::reduction)
3173 reductionDims.push_back(idx);
3174 }
3175 }
3176
3177 scf::SCFTilingOptions options;
3178 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3179 options.setReductionTilingStrategy(
3180      ReductionTilingStrategy::PartialReductionOuterReduction);
3181  options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
3182 options.setReductionDims(reductionDims);
3183 FailureOr<scf::SCFTilingResult> result =
3184 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3185
3186 if (failed(result)) {
3187 return emitSilenceableFailure(getLoc(),
3188 "failed to tile using partial reduction");
3189 }
3190 rewriter.replaceOp(target, result->replacements);
3191 for (Value initValue : result->initialValues)
3192 results.push_back(initValue.getDefiningOp());
3193 for (auto parallelTiledOp : result->tiledOps)
3194 results.push_back(parallelTiledOp);
3195 for (auto mergeOp : result->mergeOps)
3196 results.push_back(mergeOp);
3197 results.push_back(result->loops.front());
3198  return DiagnosedSilenceableFailure::success();
3199}
3200
3201//===----------------------------------------------------------------------===//
3202// TileReductionUsingForallOp
3203//===----------------------------------------------------------------------===//
3204
3205void transform::TileReductionUsingForallOp::build(
3206 OpBuilder &builder, OperationState &result, Value target,
3207 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3208 ArrayAttr mapping) {
3209 // Call the default builder.
3210  // This is future-proof regarding mixed static-dynamic sizes and sets up the
3211  // proper operand segment size attributes for multiple variadic operands.
3212 // In the absence of this, horrible bugs ensue.
3213 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3214 MLIRContext *ctx = builder.getContext();
3215 auto opTy = transform::AnyOpType::get(ctx);
3216 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3217 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3218 build(builder, result,
3219 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3220 /*target=*/target,
3221 /*reduction_dims=*/{},
3222 /*num_threads=*/staticNumThreadsAttr,
3223 /*tile_sizes=*/staticTileSizesAttr,
3224 /*mapping=*/mapping);
3225}
3226
3227DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3228 transform::TransformRewriter &rewriter, Operation *target,
3229 transform::ApplyToEachResultList &results,
3230 transform::TransformState &state) {
3231 rewriter.setInsertionPoint(target);
3232
3233 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3234 if (!partialReductionOp) {
3235    return emitSilenceableFailure(
3236        target->getLoc(),
3237 "Operation should implement PartialReductionOpInterface");
3238 }
3239 SmallVector<OpFoldResult> numThreads =
3240 getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
3241 SmallVector<OpFoldResult> tileSizes =
3242      getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
3243
3244 scf::SCFTilingOptions options;
3245 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3246 options.setReductionTilingStrategy(
3247      ReductionTilingStrategy::PartialReductionOuterParallel);
3248  if (!getNumThreads().empty()) {
3249 options.setNumThreads(numThreads);
3250 } else {
3251 options.setTileSizes(tileSizes);
3252 }
3253 if (auto mapping = getMapping()) {
3254 options.setMapping(mapping.value().getValue());
3255 }
3256 SmallVector<unsigned> reductionDims =
3257 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3258 if (reductionDims.empty()) {
3259 for (auto [idx, iteratorType] :
3260 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3261 if (iteratorType == utils::IteratorType::reduction)
3262 reductionDims.push_back(idx);
3263 }
3264 }
3265 options.setReductionDims(reductionDims);
3266 FailureOr<scf::SCFTilingResult> result =
3267 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3268
3269 if (failed(result)) {
3270 auto diag = emitSilenceableError() << "could not tile reduction";
3271 return diag;
3272 }
3273 rewriter.replaceOp(target, result->replacements);
3274
3275 for (Value initValue : result->initialValues)
3276 results.push_back(initValue.getDefiningOp());
3277 for (auto parallelTiledOp : result->tiledOps)
3278 results.push_back(parallelTiledOp);
3279 for (auto mergeOp : result->mergeOps)
3280 results.push_back(mergeOp);
3281 results.push_back(result->loops.front());
3282  return DiagnosedSilenceableFailure::success();
3283}
3284
3285//===----------------------------------------------------------------------===//
3286// ContinuousTileSizesOp
3287//===----------------------------------------------------------------------===//
3288
3289DiagnosedSilenceableFailure
3290transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3291 TransformResults &transformResults,
3292 TransformState &state) {
3293
3294 SmallVector<Operation *> targetOps =
3295 llvm::to_vector(state.getPayloadOps(getTarget()));
3296
3297 if (!llvm::hasSingleElement(targetOps)) {
3298 return mlir::emitSilenceableFailure(getLoc())
3299 << "requires exactly one target (got " << llvm::range_size(targetOps)
3300 << ")";
3301 }
3302
3303 Operation *target = *targetOps.begin();
3304 auto linalgOp = dyn_cast<LinalgOp>(target);
3305 auto tileableOp = dyn_cast<TilingInterface>(target);
3306
3307 if (!linalgOp)
3308 return emitDefiniteFailure() << "expected Linalg Op";
3309
3310 OpBuilder builder(linalgOp.getContext());
3311
3312 if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
3313 if (linalgOp.hasDynamicShape()) {
3314 auto diag = emitSilenceableError()
3315 << "cannot compute parametric tile sizes for dynamically "
3316 "shaped payload op";
3317 diag.attachNote(linalgOp->getLoc()) << "payload op";
3318 return diag;
3319 }
3320
3321 FailureOr<StaticContinuousTileSizeSpecification> spec =
3322 computeStaticContinuousTileSizes(linalgOp, getDimension(),
3323 getTargetSize());
3324 if (failed(spec)) {
3325 return emitSilenceableError()
3326 << "failed to compute multi-size tiling sizes";
3327 }
3328
3329 SmallVector<int64_t> chunkSizes;
3330
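    // Each chunk size is tileSize * tripCount, i.e. the extent of the
    // contiguous sub-range covered by that tile size. E.g. (illustrative
    // numbers), tile sizes [9, 5] with trip counts [2, 3] give chunk sizes
    // [18, 15].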
3331 for (auto &&[tileSize, tripCount] :
3332 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3333 chunkSizes.push_back(tileSize * tripCount);
3334
3335 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3336 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3337 return builder.getI64IntegerAttr(value);
3338 });
3339 };
3340 transformResults.setParams(cast<OpResult>(getTileSizes()),
3341 getI64AttrsFromI64(spec->tileSizes));
3342 transformResults.setParams(cast<OpResult>(getChunkSizes()),
3343 getI64AttrsFromI64(chunkSizes));
3344
3345    return DiagnosedSilenceableFailure::success();
3346  }
3347
3348 builder.setInsertionPoint(linalgOp);
3349
3350 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
3351 unsigned dimension = getDimension();
3352
3353 FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
3354 builder, tileableOp, dimension, targetSize, true);
3355 if (failed(spec)) {
3356 return emitSilenceableError() << "could not generate tile size computation";
3357 }
3358
3359 AffineExpr s0 = builder.getAffineSymbolExpr(0);
3360 AffineExpr s1 = builder.getAffineSymbolExpr(1);
3361 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3362 return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
3363 ofrs);
3364 };
3365
3366 SmallVector<Value> chunkSizes;
3367 Value splitPoint;
3368 for (auto &&[tileSize, tripCount] :
3369 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3370 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3371 chunkSizes.push_back(splitPoint);
3372 }
3373
3374 auto getDefiningOps = [&](ArrayRef<Value> values) {
3375 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3376 return value.getDefiningOp();
3377 });
3378 };
3379
3380 transformResults.set(cast<OpResult>(getTileSizes()),
3381 getDefiningOps(spec->tileSizes));
3382 transformResults.set(cast<OpResult>(getChunkSizes()),
3383 getDefiningOps(chunkSizes));
3384
3385  return DiagnosedSilenceableFailure::success();
3386}
3387
3388LogicalResult transform::ContinuousTileSizesOp::verify() {
3389
3390 if (getTileSizes().getType() != getChunkSizes().getType()) {
3391 return emitOpError() << "expects all results type to be the same";
3392 }
3393
3394 return success();
3395}
3396
3397void transform::ContinuousTileSizesOp::getEffects(
3398 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3399 if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3400 onlyReadsPayload(effects);
3401 else
3402 modifiesPayload(effects);
3403 onlyReadsHandle(getTargetMutable(), effects);
3404 producesHandle(getOperation()->getOpResults(), effects);
3405}
3406
3407 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
3408                                          Type targetType, Type tileSizes,
3409 Type) {
3410 printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
3411}
3412
3413 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
3414                                                 Type &targetType,
3415 Type &tileSizesType,
3416 Type &chunkSizesType) {
3417 FunctionType funcType;
3418 llvm::SMLoc typeLoc = parser.getCurrentLocation();
3419 if (failed(parser.parseType<FunctionType>(funcType)))
3420 return failure();
3421
3422 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3423 parser.emitError(typeLoc) << "expects a trailing functional type with one "
3424 "argument and one result";
3425 }
3426 targetType = funcType.getInput(0);
3427 tileSizesType = chunkSizesType = funcType.getResult(0);
3428
3429 return success();
3430}
3431
3432//===----------------------------------------------------------------------===//
3433// TileUsingForOp
3434//===----------------------------------------------------------------------===//
3435
3436void transform::TileUsingForOp::build(
3437 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3438 Value target, ArrayRef<int64_t> staticTileSizes,
3439 ArrayRef<int64_t> interchange,
3440 std::optional<ArrayRef<bool>> scalableSizes) {
3441 return build(builder, result, loopTypes,
3442 /*target=*/target,
3443 /*mixedTileSizes=*/
3444 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3445 interchange, scalableSizes);
3446}
3447
3448void transform::TileUsingForOp::build(
3449 OpBuilder &builder, OperationState &result, Value target,
3450 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3451 std::optional<ArrayRef<bool>> scalableSizes) {
3452 build(builder, result, target,
3453 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3454 interchange, scalableSizes);
3455}
3456
3457void transform::TileUsingForOp::build(
3458 OpBuilder &builder, OperationState &result, Value target,
3459 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3460 std::optional<ArrayRef<bool>> scalableSizes) {
3461  // Loop types are automatically splat by the callee; setting up one is
3462  // enough.
3463 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3464 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3465 scalableSizes);
3466}
3467
3468void transform::TileUsingForOp::build(
3469 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3470 Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3471 ArrayRef<int64_t> interchange,
3472 std::optional<ArrayRef<bool>> scalableSizes) {
3473 SmallVector<int64_t> staticTileSizes;
3474 SmallVector<Value> dynamicTileSizes;
3475 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3476 // Call the default builder which sets up the proper operands segment sizes
3477 // attributes for multiple variadic operands. In the absence of this,
3478 // horrible bugs ensue.
3479 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
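  // A tile size of zero means "do not tile this dimension" and creates no
  // loop, so e.g. (illustrative) static tile sizes [4, 0, 8] produce two loop
  // results.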
3480 unsigned numExpectedLoops =
3481 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3482 SmallVector<Type> resultTypes;
3483 resultTypes.reserve(numExpectedLoops);
3484 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3485 "expected one loop type or as many as loops");
3486 if (loopTypes.size() == 1)
3487 resultTypes.append(numExpectedLoops, loopTypes[0]);
3488 else
3489 llvm::append_range(resultTypes, loopTypes);
3490 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3491 if (scalableSizes.has_value())
3492 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3493 build(builder, result, /*tiled_linalg_op=*/target.getType(),
3494 /*loops=*/resultTypes,
3495 /*target=*/target,
3496 /*dynamic_sizes=*/dynamicTileSizes,
3497 /*static_sizes=*/staticTileSizesAttr,
3498 /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3499 /*scalable_sizes=*/expandedScalableSizes);
3500}
3501
3502LogicalResult transform::TileUsingForOp::verify() {
3503 if (getMixedSizes().size() != getScalableSizes().size())
3504 return emitOpError("expected same number of sizes (")
3505 << getMixedSizes().size() << ") and scalable sizes ("
3506 << getScalableSizes().size() << ")";
3507 ArrayRef<int64_t> staticSizes = getStaticSizes();
3508 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3509 if (getLoops().size() != numExpectedLoops)
3510 return emitOpError("expected number of loops to tile (")
3511 << numExpectedLoops << ") to match number of `loops` results ("
3512 << getLoops().size() << ")";
3513 return success();
3514}
3515
3516DiagnosedSilenceableFailure
3517transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3518 TransformResults &transformResults,
3519 TransformState &state) {
3520 ArrayRef<int64_t> tileSizes = getStaticSizes();
3521
3522 SmallVector<Operation *> targets =
3523 llvm::to_vector(state.getPayloadOps(getTarget()));
3524 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3525 SmallVector<SmallVector<int64_t>> paramSizes;
3526 dynamicSizeProducers.reserve(getDynamicSizes().size());
3527 paramSizes.reserve(getDynamicSizes().size());
3528 for (Value transformValue : getDynamicSizes()) {
3529 if (isa<ParamType>(transformValue.getType())) {
3530 dynamicSizeProducers.push_back({});
3531 ArrayRef<Attribute> params = state.getParams(transformValue);
3532 paramSizes.push_back(
3533 llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3534 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3535 })));
3536
3537 if (paramSizes.back().size() != targets.size()) {
3538 DiagnosedSilenceableFailure diag =
3539 emitSilenceableError()
3540 << "expected as many parameter values ("
3541              << paramSizes.back().size() << ") as target ops ("
3542 << targets.size() << ")";
3543 diag.attachNote(transformValue.getLoc()) << "for this parameter";
3544 return diag;
3545 }
3546
3547 continue;
3548 }
3549 paramSizes.push_back({});
3550 dynamicSizeProducers.push_back(
3551 llvm::to_vector(state.getPayloadOps(transformValue)));
3552
3553 if (dynamicSizeProducers.back().size() != targets.size()) {
3554 DiagnosedSilenceableFailure diag =
3555 emitSilenceableError()
3556 << "expected as many dynamic size-producing operations ("
3557 << dynamicSizeProducers.back().size() << ") as target ops ("
3558 << targets.size() << ")";
3559 diag.attachNote(transformValue.getLoc()) << "for this handle";
3560 return diag;
3561 }
3562
3563 for (Operation *op : dynamicSizeProducers.back()) {
3564 if (op->getNumResults() == 1 &&
3565 isa<IndexType>(op->getResult(0).getType())) {
3566 continue;
3567 }
3568
3569 DiagnosedSilenceableFailure diag =
3570 emitSilenceableError() << "expected sizes to be produced by ops "
3571 "with a single index-type result";
3572 diag.attachNote(op->getLoc()) << "size producer op";
3573 diag.attachNote(transformValue.getLoc()) << "for this handle";
3574 return diag;
3575 }
3576 }
3577
3578 SmallVector<Operation *> tiled;
3579 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3580 loops.resize(getLoops().size());
3581 auto scalableSizes = getScalableSizes();
3582 for (auto [i, op] : llvm::enumerate(targets)) {
3583 auto tilingInterface = dyn_cast<TilingInterface>(op);
3584 if (!tilingInterface) {
3585 DiagnosedSilenceableFailure diag =
3586 emitSilenceableError()
3587 << "only ops implementing TilingInterface are supported";
3588 diag.attachNote(op->getLoc()) << "target op";
3589 return diag;
3590 }
3591 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3592 DiagnosedSilenceableFailure diag =
3593 emitSilenceableError()
3594 << "too many tiles provided, expected at most "
3595 << tilingInterface.getLoopIteratorTypes().size() << " found "
3596 << tileSizes.size();
3597 diag.attachNote(op->getLoc()) << "target op";
3598 return diag;
3599 }
3600
3601 scf::SCFTilingOptions tilingOptions;
3602 if (tileSizes.empty()) {
3603 tilingOptions.setTileSizeComputationFunction(
3604 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3605 return {};
3606 });
3607 } else {
3608 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3609 Operation *) {
3610 SmallVector<OpFoldResult> sizes;
3611 sizes.reserve(tileSizes.size());
3612 unsigned dynamicIdx = 0;
3613
3614 for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3615 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3616 if (scalableSizes[ofrIdx]) {
3617              auto val = arith::ConstantIndexOp::create(
3618                  b, getLoc(), cast<IntegerAttr>(attr).getInt());
3619 Value vscale =
3620 vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
3621 sizes.push_back(
3622 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3623 } else {
3624 sizes.push_back(attr);
3625 }
3626 continue;
3627 }
3628 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3629 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3630 ++dynamicIdx;
3631 assert((dynamicSizes.empty() ^ params.empty()) &&
3632 "expected either dynamic sizes or parameters");
3633 if (!params.empty()) {
3634 sizes.push_back(b.getIndexAttr(params[index]));
3635 } else {
3636 sizes.push_back(dynamicSizes[index]->getResult(0));
3637 }
3638 }
3639 return sizes;
3640 });
3641 }
3642
3643 tilingOptions.setInterchange(getInterchange());
3644 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3645 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3646 if (failed(maybeTilingResult))
3647      return DiagnosedSilenceableFailure::definiteFailure();
3648
3649 rewriter.replaceOp(op, maybeTilingResult->replacements);
3650
3651 tiled.append(maybeTilingResult->tiledOps);
3652 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3653 loops[en2.index()].push_back(en2.value());
3654 }
3655
3656 transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3657 for (const auto &en : llvm::enumerate(loops))
3658 transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3659
3660  return DiagnosedSilenceableFailure::success();
3661}
3662
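// Reconstructs the full tile-size list by merging the static sizes with the
// dynamic size operands. E.g. (illustrative), static sizes [4, kDynamic, 8]
// with one dynamic size operand %sz yield the mixed sizes [4, %sz, 8].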
3663SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3664 ValueRange dynamic = getDynamicSizes();
3665 ArrayRef<int64_t> tileSizes = getStaticSizes();
3666 SmallVector<OpFoldResult> results;
3667 results.reserve(tileSizes.size());
3668 unsigned dynamicPos = 0;
3669 Builder builder(getContext());
3670 for (int64_t size : tileSizes) {
3671 if (size == ShapedType::kDynamic) {
3672 results.push_back(dynamic[dynamicPos++]);
3673 } else {
3674 results.push_back(builder.getIndexAttr(size));
3675 }
3676 }
3677 return results;
3678}
3679
3680void transform::TileUsingForOp::getEffects(
3681 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3682 consumesHandle(getTargetMutable(), effects);
3683 onlyReadsHandle(getDynamicSizesMutable(), effects);
3684 producesHandle(getOperation()->getOpResults(), effects);
3685 modifiesPayload(effects);
3686}
3687
3688//===----------------------------------------------------------------------===//
3689// TileUsingForallOp
3690//===----------------------------------------------------------------------===//
3691
3692void transform::TileUsingForallOp::build(OpBuilder &builder,
3693 OperationState &result, Value target,
3694 ArrayRef<int64_t> staticTileSizes,
3695 transform::TileSizesSpec,
3696 ArrayAttr mapping) {
3697 return build(builder, result,
3698 /*target=*/target,
3699 /*mixedTileSizes=*/
3700 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3701 /*_=*/TileSizesSpec(),
3702 /*mapping=*/mapping);
3703}
3704
3705void transform::TileUsingForallOp::build(OpBuilder &builder,
3706 OperationState &result, Value target,
3707 ArrayRef<OpFoldResult> mixedTileSizes,
3708 transform::TileSizesSpec,
3709 ArrayAttr mapping) {
3710 SmallVector<int64_t> staticTileSizes;
3711 SmallVector<Value> dynamicTileSizes;
3712 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3713 // Call the default builder which sets up the proper operands segment sizes
3714 // attributes for multiple variadic operands. In the absence of this,
3715 // horrible bugs ensue.
3716 MLIRContext *ctx = builder.getContext();
3717 auto operationType = transform::AnyOpType::get(ctx);
3718 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3719 build(builder, result,
3720 /*resultTypes=*/TypeRange{operationType, operationType},
3721 /*target=*/target,
3722 /*num_threads=*/ValueRange{},
3723 /*tile_sizes=*/dynamicTileSizes,
3724 /*packed_num_threads=*/Value(),
3725 /*packed_tile_sizes=*/Value(),
3726 /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3727 /*static_tile_sizes=*/staticTileSizesAttr,
3728 /*mapping=*/mapping);
3729}
3730
3731void transform::TileUsingForallOp::build(OpBuilder &builder,
3732 OperationState &result, Value target,
3733 ArrayRef<int64_t> staticNumThreads,
3734 transform::NumThreadsSpec,
3735 ArrayAttr mapping) {
3736 return build(builder, result, target,
3737 getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3738 NumThreadsSpec(), mapping);
3739}
3740
3741void transform::TileUsingForallOp::build(OpBuilder &builder,
3742 OperationState &result, Value target,
3743 ArrayRef<OpFoldResult> mixedNumThreads,
3744 transform::NumThreadsSpec,
3745 ArrayAttr mapping) {
3746 SmallVector<int64_t> staticNumThreads;
3747 SmallVector<Value> dynamicNumThreads;
3748 dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3749 staticNumThreads);
3750 // Call the default builder which sets up the proper operands segment sizes
3751 // attributes for multiple variadic operands. In the absence of this,
3752 // horrible bugs ensue.
3753 MLIRContext *ctx = builder.getContext();
3754 auto operationType = transform::AnyOpType::get(ctx);
3755 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3756 build(builder, result,
3757 /*resultTypes=*/TypeRange{operationType, operationType},
3758 /*target=*/target,
3759 /*num_threads=*/dynamicNumThreads,
3760 /*tile_sizes=*/ValueRange{},
3761 /*packed_num_threads=*/Value(),
3762 /*packed_tile_sizes=*/Value(),
3763 /*static_num_threads=*/staticNumThreadsAttr,
3764 /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3765 /*mapping=*/mapping);
3766}
3767
3768/// Given the `lbs`, `ubs` and `steps` of loops, return, for each loop, the
3769/// normalized upper bound.
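/// For example (illustrative values), a loop with lb = 4, ub = 20 and step = 3
/// gets the normalized upper bound ceilDiv(20 - 4, 3) = 6, i.e. the normalized
/// loop iterates over [0, 6) with step 1.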
3770static SmallVector<OpFoldResult>
3771 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3772                      ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3773                      ArrayRef<OpFoldResult> steps) {
3774 AffineExpr s0, s1, s2;
3775 bindSymbols(rewriter.getContext(), s0, s1, s2);
3776 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3777 SmallVector<OpFoldResult> normalizedUbs;
3778 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3779    OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3780        rewriter, loc, normalizedUbExpr, {lb, ub, step});
3781 normalizedUbs.push_back(normalizedUb);
3782 }
3783 return normalizedUbs;
3784}
3785
3786/// When a loop is normalized, the uses of the induction variable within the
3787/// loop need to be replaced with `original_lb + old_iv * original_step`.
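/// For example (illustrative values), with original_lb = 4 and
/// original_step = 3, the normalized induction variable values 0, 1, 2 map
/// back to 4, 7, 10.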
3788 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3789                                             Location loc, ValueRange ivs,
3790                                             ArrayRef<OpFoldResult> lbs,
3791                                             ArrayRef<OpFoldResult> steps) {
3792 AffineExpr s0, s1;
3793 AffineExpr d0;
3794 bindSymbols(rewriter.getContext(), s0, s1);
3795 bindDims(rewriter.getContext(), d0);
3796 AffineExpr denormExpr = s0 + d0 * s1;
3797 SmallVector<Value> denormalizedIvs;
3798
3799 for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3800    OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3801        rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3802 denormalizedIvs.push_back(
3803 getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3804 }
3805 return denormalizedIvs;
3806}
3807
3808/// Given a `scf.forall` loop, return a loop op with the loop bounds
3809/// normalized.
3810/// TODO: Replace this with a general utility to normalize `scf.forall`.
3811/// At the time of writing, this wasn't done since adding it to the `scf`
3812/// dialect would disallow the use of `affine.apply` operations due
3813/// to cyclic dependencies. To avoid churn in lit tests
3814/// with the change this was added with, defer that to a follow-up.
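/// Sketch of the rewrite (illustrative, SSA names made up):
///   scf.forall (%i) = (%lb) to (%ub) step (%s) { use(%i) }
/// becomes
///   scf.forall (%j) in (ceildiv(%ub - %lb, %s)) { use(%lb + %j * %s) }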
3815static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3816 scf::ForallOp loop) {
3817 SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3818 SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3819 SmallVector<OpFoldResult> steps = loop.getMixedStep();
3820
3821 if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3822 return loop;
3823 }
3824
3825 Location loc = loop.getLoc();
3826 SmallVector<OpFoldResult> normalizedUbs =
3827 normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3828 SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3829 rewriter.getIndexAttr(0));
3830 SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3831 rewriter.getIndexAttr(1));
3832
3833 auto normalizedForallOp = scf::ForallOp::create(
3834 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3835 loop.getOutputs(), loop.getMapping(),
3836 [](OpBuilder &, Location, ValueRange) {});
3837
3838 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3839 OpBuilder::InsertionGuard g(rewriter);
3840 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3841 rewriter.setInsertionPointToStart(normalizedLoopBlock);
3842
3843 SmallVector<Value> argValues =
3844 denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3845 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3846 normalizedForallOp.getRegionIterArgs().end());
3847 Block *origLoopBlock = loop.getBody();
3848 rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3849
3850 rewriter.replaceOp(loop, normalizedForallOp);
3851 return normalizedForallOp;
3852}
3853
3854 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
3855     RewriterBase &rewriter, transform::TransformState &state,
3856 TransformOpInterface transformOp, Operation *target,
3857 ArrayRef<OpFoldResult> mixedNumThreads,
3858 ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3859 scf::SCFTilingResult &tilingResult) {
3860 // Transform all targets one by one.
3861 auto tileableOp = dyn_cast<TilingInterface>(target);
3862 if (!tileableOp) {
3863    DiagnosedSilenceableFailure diag =
3864        transformOp.emitSilenceableError()
3865 << "only TilingInterface ops are supported";
3866 diag.attachNote(target->getLoc()) << "target op";
3867 return diag;
3868 }
3869 rewriter.setInsertionPoint(tileableOp);
3870 scf::SCFTilingOptions options;
3871 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3872 if (!mixedNumThreads.empty()) {
3873 options.setNumThreads(mixedNumThreads);
3874 } else {
3875 options.setTileSizes(mixedTileSizes);
3876 }
3877 if (mapping) {
3878 options.setMapping(mapping.value().getValue());
3879 }
3880 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3881 scf::tileUsingSCF(rewriter, tileableOp, options);
3882
3883 if (failed(maybeTilingResult))
3884 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3885
3886 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3887
3888 tilingResult = *maybeTilingResult;
3889
3890 if (mixedNumThreads.empty()) {
3891 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3892 OpBuilder::InsertionGuard g(rewriter);
3893 rewriter.setInsertionPoint(generatedForallOp);
3894 scf::ForallOp normalizedForallOp =
3895 normalizeForallLoopOp(rewriter, generatedForallOp);
3896 tilingResult.loops.front() = normalizedForallOp;
3897 }
3898
3899  return DiagnosedSilenceableFailure::success();
3900}
3901
3902DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3903    transform::TransformRewriter &rewriter,
3904    transform::TransformResults &transformResults,
3905    transform::TransformState &state) {
3906 auto transformOp = cast<TransformOpInterface>(getOperation());
3907
3908 // Result payload ops.
3909  SmallVector<Operation *> tileOps;
3910  SmallVector<Operation *> tiledOps;
3911
3912 // Unpack handles.
3913 SmallVector<OpFoldResult> mixedNumThreads;
3914  DiagnosedSilenceableFailure status =
3915      getPackedNumThreads()
3916          ? unpackSingleIndexResultPayloadOperations(
3917                state, transformOp, mixedNumThreads, getPackedNumThreads())
3918          : unpackSingleIndexResultPayloadOperations(
3919                state, transformOp, mixedNumThreads, getMixedNumThreads());
3920 if (!status.succeeded())
3921 return status;
3922 SmallVector<OpFoldResult> mixedTileSizes;
3923 status = getPackedTileSizes()
3924               ? unpackSingleIndexResultPayloadOperations(
3925                     state, transformOp, mixedTileSizes, getPackedTileSizes())
3926               : unpackSingleIndexResultPayloadOperations(
3927                     state, transformOp, mixedTileSizes, getMixedTileSizes());
3928 if (!status.succeeded())
3929 return status;
3930
3931 for (Operation *target : state.getPayloadOps(getTarget())) {
3932 scf::SCFTilingResult tilingResult;
3933    DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3934        rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3935 getMapping(), tilingResult);
3936 if (!diag.succeeded())
3937 return diag;
3938 tileOps.push_back(tilingResult.loops.front());
3939 tiledOps.append(tilingResult.tiledOps);
3940 }
3941
3942 transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3943 transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3944
3945  return DiagnosedSilenceableFailure::success();
3946}
3947
3948void transform::TileUsingForallOp::getEffects(
3949 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3950 consumesHandle(getTargetMutable(), effects);
3951 onlyReadsHandle(getTileSizesMutable(), effects);
3952 onlyReadsHandle(getNumThreadsMutable(), effects);
3953 onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3954 onlyReadsHandle(getPackedTileSizesMutable(), effects);
3955 producesHandle(getOperation()->getOpResults(), effects);
3956 modifiesPayload(effects);
3957}
3958
3959SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3960 Builder b(getContext());
3961 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3962}
3963
3964SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3965 Builder b(getContext());
3966 return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3967}
3968
3969LogicalResult TileUsingForallOp::verify() {
3970 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3971 static_cast<int>(getPackedNumThreads() != Value());
3972 if (numThreadsSpec > 1)
3973 return emitOpError(
3974 "num_threads and packed_num_threads are mutually exclusive");
3975 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3976 static_cast<int>(getPackedTileSizes() != Value());
3977 if (tileSizesSpec > 1)
3978 return emitOpError(
3979 "tile_sizes and packed_tile_sizes are mutually exclusive");
3980 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3981 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3982 "must be specified");
3983 return success();
3984}
3985
3986//===----------------------------------------------------------------------===//
3987// VectorizeChildrenAndApplyPatternsOp
3988//===----------------------------------------------------------------------===//
3989
3990void transform::VectorizeChildrenAndApplyPatternsOp::build(
3991 OpBuilder &builder, OperationState &result, Value target,
3992 bool foldTypeExtensionsIntoContract, bool vectorizePadding,
3993 bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3994 result.addOperands(target);
3995 if (foldTypeExtensionsIntoContract) {
3996 result.addAttribute(
3997 VectorizeChildrenAndApplyPatternsOp::
3998 getFoldTypeExtensionsIntoContractAttrName(result.name),
3999 builder.getUnitAttr());
4000 }
4001 if (vectorizePadding) {
4002 result.addAttribute(
4003 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4004 result.name),
4005 builder.getUnitAttr());
4006 }
4007 if (vectorizeExtract) {
4008 result.addAttribute(
4009 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4010 result.name),
4011 builder.getUnitAttr());
4012 }
4013 if (flatten1DDepthwiseConv) {
4014 result.addAttribute(
4015 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4016 result.name),
4017 builder.getUnitAttr());
4018 }
4019 result.addTypes(transform::AnyOpType::get(builder.getContext()));
4020}
4021
4022namespace {
4023/// This is a helper used only to call vectorize via a pattern inside of
4024/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
4025struct VectorizationPattern : public RewritePattern {
4026 explicit VectorizationPattern(MLIRContext *context,
4027 bool vectorizeExtract = false,
4028 bool flattenConv = false)
4029 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
4030 vectorizeNDExtract(vectorizeExtract),
4031 flatten1DDepthwiseConv(flattenConv) {}
4032 LogicalResult matchAndRewrite(Operation *op,
4033 PatternRewriter &rewriter) const override {
4035 return rewriter.notifyMatchFailure(op,
4036 "Unsupported Op, cannot vectorize");
4037 FailureOr<VectorizationResult> vectorResults =
4038 vectorize(rewriter, op, /*inputVectorSizes=*/{},
4039 /*inputScalableVecDims=*/{}, vectorizeNDExtract,
4040 flatten1DDepthwiseConv);
4041 if (failed(vectorResults))
4042 return failure();
4043 rewriter.replaceOp(op, vectorResults->replacements);
4044 return success();
4045 }
4046
4047private:
4048 /// Controls whether to vectorize `tensor.extract` when the input tensor is
4049 /// rank >= 2.
4050 bool vectorizeNDExtract = false;
4051 /// Controls whether to "flatten" the channel dimension when vectorising 1D
4052  /// depthwise convolutions. This should lead to better vectorization for
4053 /// tensors with a low number of channel dimensions.
4054 bool flatten1DDepthwiseConv = false;
4055};
4056} // namespace
4057
4058DiagnosedSilenceableFailure
4059transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4060 transform::TransformRewriter &rewriter, Operation *target,
4061 transform::ApplyToEachResultList &results,
4062 transform::TransformState &state) {
4063 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4064 auto diag = this->emitOpError("requires isolated-from-above targets");
4065 diag.attachNote(target->getLoc()) << "non-isolated target";
4066    return DiagnosedSilenceableFailure::definiteFailure();
4067  }
4068
4069 MLIRContext *ctx = getContext();
4070 RewritePatternSet patterns(ctx);
4071 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4072 getFlatten_1dDepthwiseConv());
4073
4074 if (!getDisableTransferPermutationMapLoweringPatterns())
4075    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
4076
4077 if (!getDisableMultiReductionToContractPatterns())
4078    vector::populateVectorReductionToContractPatterns(patterns);
4079
4081
4082 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
4083 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4084 /*benefit=*/2);
4085 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4086 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4088
4089 patterns.add<CopyVectorizationPattern>(ctx);
4090
4091 if (getFoldTypeExtensionsIntoContract())
4093
4094 if (getVectorizePadding()) {
4096 // This creates an alternative path for lowering tensor.pad - by
4097 // decomposing it into e.g. linalg.fill.
4099 }
4101
4102 TrackingListener listener(state, *this);
4103 if (failed(
4104          applyPatternsGreedily(target, std::move(patterns),
4105                                GreedyRewriteConfig().setListener(&listener))))
4106 return emitDefaultDefiniteFailure(target);
4107
4108 results.push_back(target);
4110}
4111
4112//===----------------------------------------------------------------------===//
4113// VectorizeOp
4114//===----------------------------------------------------------------------===//
4115
4116DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4117 transform::TransformRewriter &rewriter,
4118 mlir::transform::TransformResults &transformResults,
4119 mlir::transform::TransformState &state) {
4120 auto targets = state.getPayloadOps(getTarget());
4121 if (std::empty(targets))
4122    return DiagnosedSilenceableFailure::success();
4123  auto transformOp = cast<TransformOpInterface>(getOperation());
4124 SmallVector<int64_t> vectorSizes;
4125 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
4126 state, transformOp, getMixedVectorSizes(), vectorSizes);
4127 if (!status.succeeded())
4128 return status;
4129
4130 // TODO: Check that the correct number of vectorSizes was provided.
4131 for (Operation *target : targets) {
4133 return mlir::emitSilenceableFailure(target->getLoc())
4134 << "Unsupported Op, cannot vectorize";
4135 }
4136 FailureOr<VectorizationResult> vectorResults =
4137 linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
4138 getVectorizeNdExtract().value_or(false),
4139 /*flatten1DDepthwiseConv=*/false,
4140 getAssumeDynamicDimsMatchVecSizes().value_or(false),
4141 getCreateNamedContraction().value_or(false));
4142 if (failed(vectorResults)) {
4143 return mlir::emitSilenceableFailure(target->getLoc())
4144 << "Attempted to vectorize, but failed";
4145 }
4146 rewriter.replaceOp(target, vectorResults->replacements);
4147 }
4148
4149  return DiagnosedSilenceableFailure::success();
4150}
4151
4152void transform::VectorizeOp::getEffects(
4153 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4154 consumesHandle(getTargetMutable(), effects);
4155 onlyReadsHandle(getVectorSizesMutable(), effects);
4156 modifiesPayload(effects);
4157}
4158
4159SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4160 OpBuilder b(getContext());
4161 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
4162}
4163
4164LogicalResult transform::VectorizeOp::verify() {
4165 if (getStaticVectorSizes().size() != getScalableSizes().size())
4166 return emitOpError("expected same number of vector sizes (")
4167 << getStaticVectorSizes().size() << ") and scalable sizes ("
4168 << getScalableSizes().size() << ")";
4169 return success();
4170}
4171
4172//===----------------------------------------------------------------------===//
4173// HoistRedundantVectorTransfersOp
4174//===----------------------------------------------------------------------===//
4175
4176DiagnosedSilenceableFailure
4177transform::HoistRedundantVectorTransfersOp::applyToOne(
4178 transform::TransformRewriter &rewriter, func::FuncOp target,
4179 transform::ApplyToEachResultList &results,
4180 transform::TransformState &state) {
4181 // WARNING: This hoisting does not model parallelism and is generally
4182 // incorrect when used on distributed loops with memref semantics!
4183 // TODO: obsolete and should be retired.
4184 linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
4185 results.push_back(target);
4186  return DiagnosedSilenceableFailure::success();
4187}
4188
4189//===----------------------------------------------------------------------===//
4190// HoistRedundantVectorBroadcastsOp
4191//===----------------------------------------------------------------------===//
4192
4193DiagnosedSilenceableFailure
4194transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4195 transform::TransformRewriter &rewriter, mlir::Operation *target,
4196 transform::ApplyToEachResultList &results,
4197 transform::TransformState &state) {
4198 rewriter.setInsertionPoint(target);
4199  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
4200  results.push_back(target);
4201  return DiagnosedSilenceableFailure::success();
4202}
4203
4204//===----------------------------------------------------------------------===//
4205// ConvertConv2DToImg2ColOp.
4206//===----------------------------------------------------------------------===//
4207
4208DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4209 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4210 transform::ApplyToEachResultList &results,
4211 transform::TransformState &state) {
4212 rewriter.setInsertionPoint(target);
4213 auto maybeTransformed =
4214      TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
4215          target)
4216 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4217 return rewriteInIm2Col(rewriter, op);
4218 })
4219 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4220 return rewriteInIm2Col(rewriter, op);
4221 })
4222 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4223 return rewriteInIm2Col(rewriter, op);
4224 })
4225 .Case([&](linalg::Conv2DNchwFchwOp op) {
4226 return rewriteInIm2Col(rewriter, op);
4227 })
4228 .Default([&](Operation *op) {
4229 return rewriter.notifyMatchFailure(op, "not supported");
4230 });
4231 if (failed(maybeTransformed))
4232 return emitDefaultSilenceableFailure(target);
4233 // Handle to the operation producing the img2col tensor.
4234 results.push_back(maybeTransformed->first);
4235 // Handle to the operation that replaces the original convolution.
4236 results.push_back(maybeTransformed->second);
4237  return DiagnosedSilenceableFailure::success();
4238}
4239
4240//===----------------------------------------------------------------------===//
4241// FlattenElementwiseLinalgOp.
4242//===----------------------------------------------------------------------===//
4243
4244DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4245 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4246 transform::ApplyToEachResultList &results,
4247 transform::TransformState &state) {
4248 rewriter.setInsertionPoint(target);
4249 if (!isElementwise(target))
4250 return mlir::emitSilenceableFailure(target->getLoc())
4251 << "only elementwise flattening is supported";
4252
4253 // If rank <= 1, do nothing
4254 if (target.getNumLoops() <= 1) {
4255 results.push_back(target);
4256    return DiagnosedSilenceableFailure::success();
4257  }
4258
4259 // Attempt to flatten all dims to one.
4260 ReassociationIndices reassociation(target.getNumLoops());
4261 std::iota(reassociation.begin(), reassociation.end(), 0);
4262 auto maybeFlattened =
4263 collapseOpIterationDims(target, reassociation, rewriter);
4264 if (failed(maybeFlattened))
4265 return mlir::emitSilenceableFailure(target->getLoc())
4266 << "attempted to flatten, but failed";
4267 results.push_back(maybeFlattened->collapsedOp);
4268 rewriter.replaceOp(target, maybeFlattened->results);
4269  return DiagnosedSilenceableFailure::success();
4270}
4271
4272//===----------------------------------------------------------------------===//
4273// TransposeConv2DOp
4274//===----------------------------------------------------------------------===//
4275
4276DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4277 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4278 transform::ApplyToEachResultList &results,
4279 transform::TransformState &state) {
4280 rewriter.setInsertionPoint(target);
4281 auto maybeTransformed =
4282      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4283          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4284 return transposeConv2D(rewriter, op);
4285 })
4286 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4287 return transposeConv2D(rewriter, op);
4288 })
4289 .Default([&](Operation *op) {
4290 return rewriter.notifyMatchFailure(op, "not supported");
4291 });
4292 if (failed(maybeTransformed))
4293 return emitDefaultSilenceableFailure(target);
4294 // Handle to the new Conv2D operation with transposed filters
4295 results.push_back(*maybeTransformed);
4296  return DiagnosedSilenceableFailure::success();
4297}
4298
4299//===----------------------------------------------------------------------===//
4300// TransposeMatmulOp
4301//===----------------------------------------------------------------------===//
4302
4303DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4304 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4305 transform::ApplyToEachResultList &results,
4306 transform::TransformState &state) {
4307 rewriter.setInsertionPoint(target);
4308 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4309 auto maybeTransformed =
4310      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4311          .Case([&](linalg::MatmulOp op) {
4312 return transposeMatmul(rewriter, op, transposeLHS);
4313 })
4314 .Case([&](linalg::BatchMatmulOp op) {
4315 return transposeBatchMatmul(rewriter, op, transposeLHS);
4316 })
4317 .Default([&](Operation *op) { return failure(); });
4318 if (failed(maybeTransformed))
4319 return emitSilenceableFailure(target->getLoc()) << "not supported";
4320 // Handle to the new Matmul operation with transposed filters
4321 results.push_back(*maybeTransformed);
4322  return DiagnosedSilenceableFailure::success();
4323}
4324
4325//===----------------------------------------------------------------------===//
4326// InsertSliceToCopyOp
4327//===----------------------------------------------------------------------===//
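// Both cases below rewrite a tensor.insert_slice / tensor.parallel_insert_slice
// whose source is not already produced by a linalg.copy into an explicit copy.
// Sketch of the rewrite (illustrative, names made up, types omitted):
//
//   %r = tensor.insert_slice %src into %dest[...]
// becomes
//   %e = tensor.extract_slice %dest[...]
//   %c = linalg.copy ins(%src) outs(%e)
//   %r = tensor.insert_slice %c into %dest[...]
//
// so that the returned handle points at the newly created linalg.copy.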
4328template <typename OpTy>
4329static DiagnosedSilenceableFailure
4330doit(RewriterBase &rewriter, OpTy target,
4331     transform::ApplyToEachResultList &results,
4332     transform::TransformState &state) {
4333  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4334 tensor::ParallelInsertSliceOp>() &&
4335 "wrong op type");
4336
4337 if (auto copySource =
4338 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4339 results.push_back(copySource);
4340    return DiagnosedSilenceableFailure::success();
4341  }
4342
4343 // If we are inside a `ParallelCombiningOp` region, temporarily set the
4344 // insertion point outside: only ops implementing ParallelCombiningOpInterface
4345 // are allowed in there.
4346 if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4347 rewriter.setInsertionPoint(target->getParentOp());
4348
4349 Value extracted = tensor::ExtractSliceOp::create(
4350 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4351 target.getMixedSizes(), target.getMixedStrides());
4352 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4353 target.getSource(), extracted)
4354 .getResult(0);
4355 // Reset the insertion point.
4356 rewriter.setInsertionPoint(target);
4357 rewriter.replaceOpWithNewOp<OpTy>(
4358 target, copied, target.getDest(), target.getMixedOffsets(),
4359 target.getMixedSizes(), target.getMixedStrides());
4360
4361 results.push_back(copied.getDefiningOp());
4362  return DiagnosedSilenceableFailure::success();
4363}
4364
4365DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4366 transform::TransformRewriter &rewriter, Operation *targetOp,
4367 transform::ApplyToEachResultList &results,
4368 transform::TransformState &state) {
4369
4370 rewriter.setInsertionPoint(targetOp);
4371 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4372 return doit(rewriter, target, results, state);
4373 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4374 return doit(rewriter, target, results, state);
4375
4376 DiagnosedSilenceableFailure diag =
4377 emitSilenceableError()
4378 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4379 diag.attachNote(targetOp->getLoc()) << "target op";
4380 return diag;
4381}
4382
4383//===----------------------------------------------------------------------===//
4384// MapCopyToThreadsOp
4385//===----------------------------------------------------------------------===//
4386
4387DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4388 transform::TransformRewriter &rewriter, Operation *target,
4389 transform::ApplyToEachResultList &results,
4390 transform::TransformState &state) {
4391 // Check if the op is supported.
4392 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4393 DiagnosedSilenceableFailure diag =
4394 emitSilenceableError()
4395 << "only linalg.copy and tensor.pad target ops are supported";
4396 diag.attachNote(target->getLoc()) << "target op";
4397 return diag;
4398 }
4399 assert(target->getNumResults() == 1 && "expected single result");
4400 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4401 if (!resultShapedType.hasStaticShape()) {
4402 DiagnosedSilenceableFailure diag =
4403 emitSilenceableError()
4404 << "only statically sized ops of rank <= 3 are supported";
4405 diag.attachNote(target->getLoc()) << "target op";
4406 return diag;
4407 }
4408
4409 // Conservatively set the minimum viable desired bitwidth alignment.
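  // E.g. (illustrative), a requested alignment of 128 bits is kept as-is for
  // f32 elements since 128 % 32 == 0; if the element bitwidth does not divide
  // the requested alignment, fall back to the element bitwidth itself.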
4410 int64_t desiredBitAlignment = getDesiredBitAlignment();
4411 int64_t eltBitwidth =
4412 resultShapedType.getElementType().getIntOrFloatBitWidth();
4413 if (desiredBitAlignment % eltBitwidth != 0) {
4414 desiredBitAlignment = eltBitwidth;
4415 }
4416
4417 gpu::CopyMappingInfo mapping(
4418 /*ctx=*/getContext(),
4419 /*totalNumThreads=*/getTotalNumThreads(),
4420 /*alignment=*/desiredBitAlignment,
4421 /*sizes=*/resultShapedType.getShape(),
4422 /*favorPredication=*/false,
4423 /*elementalBitwidth=*/
4424 resultShapedType.getElementType().getIntOrFloatBitWidth());
4425 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4426 DiagnosedSilenceableFailure diag =
4427 emitSilenceableError()
4428 << "too few threads to map copy op to threads on the most minor "
4429 "dimension, given alignment and vector size constraints, try "
4430 "smaller tile size of mapping to more threads";
4431 diag.attachNote(target->getLoc()) << "target op";
4432 return diag;
4433 }
4434
4435 // OpBuilder only used to compute attributes.
4436 OpBuilder b(getContext());
4437 scf::SCFTilingResult tilingResult;
4438 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4439 /*rewriter=*/rewriter,
4440 /*state=*/state,
4441 /*transformOp=*/*this,
4442 /*target=*/target,
4443 /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4444 /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4445 /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4446 /*tilingResult=*/tilingResult);
4447 if (!diag.succeeded())
4448 return diag;
4449
4450 results.push_back(tilingResult.loops.front());
4451 for (auto op : tilingResult.tiledOps)
4452 results.push_back(op);
4453  return DiagnosedSilenceableFailure::success();
4454}
4455
4456//===----------------------------------------------------------------------===//
4457// WinogradConv2DOp
4458//===----------------------------------------------------------------------===//
4459
4460DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4461 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4462 transform::ApplyToEachResultList &results,
4463 transform::TransformState &state) {
4464 rewriter.setInsertionPoint(target);
4465 FailureOr<Operation *> maybeTransformed = failure();
4466 bool supported = TypeSwitch<Operation *, bool>(target)
4467 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4468 maybeTransformed =
4469 winogradConv2D(rewriter, op, getFmr());
4470 return true;
4471 })
4472 .Default([&](Operation *op) { return false; });
4473
4474 if (!supported) {
4475 return emitSilenceableError()
4476 << "this operation is not supported to convert to Winograd Conv2D";
4477 }
4478
4479 if (failed(maybeTransformed)) {
4480 return emitSilenceableError() << "apply Winograd Conv2D failed";
4481 }
4482
4483 results.push_back(*maybeTransformed);
4484  return DiagnosedSilenceableFailure::success();
4485}
4486
4487DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4488 transform::TransformRewriter &rewriter, Operation *target,
4489 transform::ApplyToEachResultList &results,
4490 transform::TransformState &state) {
4491 rewriter.setInsertionPoint(target);
4492 FailureOr<Operation *> maybeTransformed = failure();
4493 bool supported =
4494      TypeSwitch<Operation *, bool>(target)
4495          .Case([&](linalg::WinogradFilterTransformOp op) {
4496 maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4497 return true;
4498 })
4499 .Case([&](linalg::WinogradInputTransformOp op) {
4500 maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4501 return true;
4502 })
4503 .Case([&](linalg::WinogradOutputTransformOp op) {
4504 maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4505 return true;
4506 })
4507 .Default(false);
4508
4509 if (!supported) {
4510 DiagnosedSilenceableFailure diag =
4511 emitSilenceableError()
4512 << "this operation is not supported to decompose into other operations";
4513 diag.attachNote(target->getLoc()) << "target op";
4514 return diag;
4515 }
4516
4517 if (failed(maybeTransformed)) {
4518 DiagnosedSilenceableFailure diag =
4519 emitSilenceableError() << "decompose Winograd operations failed";
4520 diag.attachNote(target->getLoc()) << "target op";
4521 return diag;
4522 }
4523
4524 results.push_back(*maybeTransformed);
4525  return DiagnosedSilenceableFailure::success();
4526}
4527
4528#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4529
4530#define GET_OP_CLASSES
4531#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"