1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
41#include "mlir/Support/LLVM.h"
43#include "llvm/ADT/STLExtras.h"
44#include "llvm/ADT/ScopeExit.h"
45#include "llvm/ADT/SmallPtrSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
49#include <type_traits>
50
51using namespace mlir;
52using namespace mlir::linalg;
53using namespace mlir::transform;
54
55#define DEBUG_TYPE "linalg-transforms"
56
57/// Attempts to apply the pattern specified as template argument to the given
58/// operation. The pattern is expected to have a `returningMatchAndRewrite`
59/// function that returns the "main" result or failure. Returns failure if the
60/// pattern failed to apply. Extra arguments are forwarded to the pattern
61/// constructor.
62template <typename PatternTy, typename... Args>
63static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
64 // Check if the given operation has the type expected by the pattern.
65 using OpTy = typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
68 if (!op)
69 return failure();
70
71 // Apply the pattern directly to the op.
72 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
73 // We want to discourage direct use of PatternRewriter in APIs but in this
74 // very specific case, an IRRewriter is not enough.
75 PatternRewriter rewriter(operation->getContext());
76 rewriter.setInsertionPoint(operation);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
78 if (failed(result))
79 return failure();
80 return cast<LinalgOp>(result->getOperation());
81}
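// Illustrative sketch (not from the upstream file): this is how the DecomposeOp
// handler further below drives tryApply, shown with the DOWNSCALE macro
// expanded by hand.
//
//   FailureOr<LinalgOp> res =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(target);
//   if (succeeded(res))
//     results.push_back(*res);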
82
83/// Assuming that each `ofr` in `ofrs` is an index attr, a param of index type,
84/// or a transform dialect handle mapped to exactly one op with one index
85/// result, append that value to `result`.
86static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
87 transform::TransformState &state, TransformOpInterface transformOp,
88 SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
89 for (OpFoldResult ofr : ofrs) {
90 if (auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
93 result.push_back(ofr);
94 continue;
95 }
96
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
99 ArrayRef<Attribute> params = state.getParams(transformValue);
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 << "requires exactly one parameter associated";
103 result.push_back(params[0]);
104 continue;
105 }
106
107 auto payloadOps = state.getPayloadOps(transformValue);
108 if (!llvm::hasSingleElement(payloadOps)) {
109 DiagnosedSilenceableFailure diag =
110 transformOp.emitSilenceableError()
111 << "handle must be mapped to exactly one payload op";
112 diag.attachNote(transformValue.getLoc())
113 << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
114 return diag;
115 }
116
117 Operation *op = *payloadOps.begin();
118 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
119 DiagnosedSilenceableFailure diag =
120 transformOp.emitSilenceableError()
121 << "payload op must have exactly 1 index result";
122 diag.attachNote(op->getLoc())
123 << "has " << op->getNumResults() << " results";
124 return diag;
125 }
126 result.push_back(op->getResult(0));
127 }
128
129 return DiagnosedSilenceableFailure::success();
130}
131
132// Given a list of params that are index attrs or a list of OpFoldResults
133// that are either index attrs or op handles, return a list of OpFoldResults
134// of index attrs or a list of OpFoldResults where all op handles are
135// replaced with the first (and only) OpResult of that payload op.
136// (There must be exactly one parameter associated with the AnyParamType or
137// one mapped payload op which must have exactly one index result.)
138static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
139 transform::TransformState &state, TransformOpInterface transformOp,
140 SmallVector<OpFoldResult> &result, Value packedHandle) {
141 if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
142 ArrayRef<Attribute> params = state.getParams(packedHandle);
143 for (auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 << "expected the parameter to be associated with an integer "
147 "attribute";
148 result.push_back(param);
149 }
150 return DiagnosedSilenceableFailure::success();
151 }
152
153 for (Operation *op : state.getPayloadOps(packedHandle)) {
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
155 DiagnosedSilenceableFailure diag =
156 transformOp.emitSilenceableError()
157 << "payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 << "has " << op->getNumResults() << " results";
160 return diag;
161 }
162 result.push_back(op->getResult(0));
163 }
164
165 return DiagnosedSilenceableFailure::success();
166}
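// Illustrative call sketch (not from the upstream file; `getMixedSizes()` is a
// hypothetical accessor standing in for any op returning mixed OpFoldResults):
// tiling-style transform ops use these helpers to turn attrs, !transform.param
// values, and handles to single-index-result ops into plain OpFoldResults.
//
//   SmallVector<OpFoldResult> sizes;
//   DiagnosedSilenceableFailure status =
//       unpackSingleIndexResultPayloadOperations(state, transformOp, sizes,
//                                                getMixedSizes());
//   if (!status.succeeded())
//     return status;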
167
168/// When possible, converts each `OpFoldResult` in `mixedResults` to
169/// an integer if the value can be statically inferred. If a result
170/// is a `Value` then it must be either a `ParamType` or a handle
171/// to a constant-like op.
172static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
173 TransformState &state, TransformOpInterface &transformOp,
174 ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
175 for (OpFoldResult paramOrHandle : mixedResults) {
176 if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
178 continue;
179 } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
180 ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
181 if (params.size() != 1)
182 return transformOp.emitSilenceableError() << "expected a single param";
183 reified.push_back(
184 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
185 continue;
186 }
187
188 Value handle = cast<Value>(paramOrHandle);
189 if (!isa<TransformHandleTypeInterface>(handle.getType()))
190 return transformOp.emitSilenceableError() << "unexpected value handle";
191 auto payload = state.getPayloadOps(handle);
192 if (!llvm::hasSingleElement(payload))
193 return transformOp.emitSilenceableError()
194 << "requires param or handle that is mapped to 1 payload op";
195
196 Operation *paramOrHandlePayloadOp = *payload.begin();
197 if (paramOrHandlePayloadOp->getNumResults() != 1 ||
198 !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
199 return transformOp.emitSilenceableError()
200 << "requires param or handle to be result of op with 1 index "
201 "result";
202 }
203
204 IntegerAttr attr;
205 if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
206 return transformOp.emitSilenceableError()
207 << "requires param or handle to be the result of a constant like "
208 "op";
209
210 reified.push_back(attr.getInt());
211 }
212 return DiagnosedSilenceableFailure::success();
213}
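// Worked example (editorial note, not from the upstream file): with
// mixedResults = { IntegerAttr 4, a !transform.param carrying 8, a handle
// mapped to `%c16 = arith.constant 16 : index` }, the helper appends
// {4, 8, 16} to `reified`. A handle whose payload is not a single
// constant-like op with one index result produces a silenceable failure.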
214
215//===----------------------------------------------------------------------===//
216// Apply...PatternsOp
217//===----------------------------------------------------------------------===//
218
219void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
220 RewritePatternSet &patterns) {
221 linalg::populateEraseUnnecessaryInputsPatterns(patterns);
222}
223
224void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
225 RewritePatternSet &patterns) {
226 linalg::populateDecomposePackUnpackPatterns(patterns);
227}
228
229void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
230 RewritePatternSet &patterns) {
231 linalg::populateDecomposePadPatterns(patterns);
232}
233
234void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
235 RewritePatternSet &patterns) {
236 linalg::ControlDropUnitDims options;
237 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
238}
239
240void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
241 RewritePatternSet &patterns) {
242 linalg::ControlDropUnitDims options;
243 options.rankReductionStrategy =
244 linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
245 linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
246}
247
248void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
249 RewritePatternSet &patterns) {
250 linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
251}
252
253void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
254 RewritePatternSet &patterns) {
255 linalg::populateFoldAddIntoDestPatterns(patterns);
256}
257
258void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
259 RewritePatternSet &patterns) {
260 linalg::populatePadOpVectorizationPatterns(patterns);
261}
262
263void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
264 RewritePatternSet &patterns) {
265 linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
266}
267
268void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
269 RewritePatternSet &patterns) {
270 linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
271}
272
273//===----------------------------------------------------------------------===//
274// BufferizeToAllocationOp
275//===----------------------------------------------------------------------===//
276
277namespace {
278class NewOpsListener : public RewriterBase::ForwardingListener {
279public:
280 using RewriterBase::ForwardingListener::ForwardingListener;
281
282 SmallVector<Operation *> getNewOps() const {
283 return SmallVector<Operation *>(newOps.begin(), newOps.end());
284 }
285
286private:
287 void notifyOperationInserted(Operation *op,
288 OpBuilder::InsertPoint previous) override {
289 ForwardingListener::notifyOperationInserted(op, previous);
290 // We only care about newly created ops.
291 if (previous.isSet())
292 return;
293 auto inserted = newOps.insert(op);
294 (void)inserted;
295 assert(inserted.second && "expected newly created op");
296 }
297
298 void notifyOperationErased(Operation *op) override {
299 ForwardingListener::notifyOperationErased(op);
300 op->walk([&](Operation *op) { newOps.erase(op); });
301 }
302
303 SmallPtrSet<Operation *, 4> newOps;
304};
305} // namespace
306
307DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
308 transform::TransformRewriter &rewriter,
309 transform::TransformResults &results, transform::TransformState &state) {
310 // Attach listener to keep track of newly created ops.
311 OpBuilder::Listener *previousListener = rewriter.getListener();
312 auto resetListener =
313 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
314 NewOpsListener newOpsListener(previousListener);
315 rewriter.setListener(&newOpsListener);
316
317 linalg::BufferizeToAllocationOptions options;
318 if (getMemcpyOp() == "bufferization.materialize_in_destination") {
319 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
320 MaterializeInDestination;
321 } else if (getMemcpyOp() == "memref.copy") {
322 options.memcpyOp =
323 linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
324 } else if (getMemcpyOp() == "linalg.copy") {
325 options.memcpyOp =
326 linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
327 } else {
328 llvm_unreachable("invalid memcpy op");
329 }
330 if (getAllocOp() == "memref.alloc") {
331 options.allocOp =
332 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
333 } else if (getAllocOp() == "memref.alloca") {
334 options.allocOp =
335 linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
336 } else {
337 llvm_unreachable("invalid alloc op");
338 }
339 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
340 options.emitDealloc = getEmitDealloc();
341
342 // Bufferize ops.
343 Attribute memorySpace =
344 getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
345 SmallVector<Value> allocatedBuffers;
346 for (Operation *op : state.getPayloadOps(getTarget())) {
347 Value buffer =
348 linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
349 if (!buffer) {
350 DiagnosedSilenceableFailure diag = emitSilenceableError()
351 << "failed to bufferize operation";
352 diag.attachNote(op->getLoc()) << "target payload op";
353 return diag;
354 }
355 allocatedBuffers.push_back(buffer);
356 }
357
358 // Set results.
359 results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
360 results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
361 return DiagnosedSilenceableFailure::success();
362}
363
364void transform::BufferizeToAllocationOp::getEffects(
365 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
366 if (getBufferizeDestinationOnly()) {
367 // The destination is replaced with a newly allocated buffer, but the op
368 // itself remains in place.
369 onlyReadsHandle(getTargetMutable(), effects);
370 } else {
371 consumesHandle(getTargetMutable(), effects);
372 }
373 producesHandle(getOperation()->getOpResults(), effects);
374 modifiesPayload(effects);
375}
376
377LogicalResult transform::BufferizeToAllocationOp::verify() {
378 if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
379 getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
380 return emitOpError() << "unsupported memcpy op";
381 if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
382 return emitOpError() << "unsupported alloc op";
383 return success();
384}
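// Accepted attribute values, mirroring the checks in verify() and apply()
// above (editorial summary, not from the upstream file):
//   memcpy_op: "bufferization.materialize_in_destination", "memref.copy",
//              or "linalg.copy"
//   alloc_op : "memref.alloc" or "memref.alloca"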
385
386//===----------------------------------------------------------------------===//
387// PromoteTensorOp
388//===----------------------------------------------------------------------===//
389
390/// Return true if the operand may be read from by its owner. This is currently
391/// very conservative and only looks inside linalg operations to prevent
392/// unintentional data loss.
393static bool mayBeRead(OpOperand &operand) {
394 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.getOwner());
395
396 // Be conservative about ops we cannot analyze deeper.
397 if (!linalgOp)
398 return true;
399
400 // Look inside linalg ops.
401 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
402 return !blockArgument.use_empty();
403}
404
405/// Return true if the value may be read through any of its uses.
406static bool mayBeRead(Value value) {
407 // If the value has reference semantics, it
408 // may be read through any alias...
409 if (!isa<TensorType, FloatType, IntegerType>(value.getType()))
410 return true;
411 return llvm::any_of(value.getUses(),
412 static_cast<bool (&)(OpOperand &)>(mayBeRead));
413}
414
415DiagnosedSilenceableFailure
416transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
417 transform::TransformResults &results,
418 transform::TransformState &state) {
419 SmallVector<Value> promoted;
420 for (Value tensor : state.getPayloadValues(getTensor())) {
421 auto type = dyn_cast<RankedTensorType>(tensor.getType());
422 if (!type) {
423 return emitSilenceableError() << "non-tensor type: " << tensor;
424 }
425
426 Operation *definingOp = tensor.getDefiningOp();
427 if (definingOp)
428 rewriter.setInsertionPointAfter(definingOp);
429 else
430 rewriter.setInsertionPointToStart(cast<BlockArgument>(tensor).getOwner());
431
432 // Check this before we emit operations using this value.
433 bool needsMaterialization = mayBeRead(tensor);
434
435 SmallVector<Value> dynamicDims;
436 llvm::SmallPtrSet<Operation *, 4> preservedOps;
437 for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
438 if (!ShapedType::isDynamic(dim))
439 continue;
440 Value cst =
441 arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
442 auto dimOp =
443 tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
444 preservedOps.insert(dimOp);
445 dynamicDims.push_back(dimOp);
446 }
447 auto allocation = bufferization::AllocTensorOp::create(
448 rewriter, tensor.getLoc(), type, dynamicDims);
449 // Set memory space if provided.
450 if (getMemorySpaceAttr())
451 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
452 Value allocated = allocation;
453
454 // Only insert a materialization (typically bufferizes to a copy) when the
455 // value may be read from.
456 if (needsMaterialization) {
457 auto copy = bufferization::MaterializeInDestinationOp::create(
458 rewriter, tensor.getLoc(), tensor, allocated);
459 preservedOps.insert(copy);
460 promoted.push_back(copy.getResult());
461 } else {
462 promoted.push_back(allocated);
463 }
464 rewriter.replaceAllUsesExcept(tensor, promoted.back(), preservedOps);
465 }
466 results.setValues(cast<OpResult>(getPromoted()), promoted);
467 return DiagnosedSilenceableFailure::success();
468}
469
470void transform::PromoteTensorOp::getEffects(
471 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
472 transform::onlyReadsHandle(getTensorMutable(), effects);
473 transform::producesHandle(getOperation()->getOpResults(), effects);
474 transform::modifiesPayload(effects);
475}
476
477//===----------------------------------------------------------------------===//
478// DecomposeOp
479//===----------------------------------------------------------------------===//
480
481DiagnosedSilenceableFailure
482transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
483 LinalgOp target,
484 transform::ApplyToEachResultList &results,
485 transform::TransformState &state) {
486#define DOWNSCALE(trans) \
487 { \
488 FailureOr<LinalgOp> res = tryApply<trans>(target); \
489 if (succeeded(res)) { \
490 results.push_back(*res); \
491 return DiagnosedSilenceableFailure::success(); \
492 } \
493 }
494
495#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
496#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
497
498 DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
499 DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
500 DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
501 DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
502 DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
503 DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
504 DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
505 DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
506 DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
507 DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
508 DOWNSCALE(DownscaleConv2DOp)
509#undef DOWNSCALE_NORMAL
510#undef DOWNSCALE_CALL
511#undef DOWNSCALE
512 return emitDefaultSilenceableFailure(target);
513}
514
515//===----------------------------------------------------------------------===//
516// DecomposeInterfaceOp
517//===----------------------------------------------------------------------===//
518
519// Decompose the target operation if it implements the AggregatedOpInterface.
520// Push the decomposed operations (the ones that replace the values produced
521// by \p target) into `results`.
522DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
523 transform::TransformRewriter &rewriter, Operation *target,
524 transform::ApplyToEachResultList &results,
525 transform::TransformState &state) {
526 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
527 if (!decomposableOp) {
529 "payload is not a decomposable op"));
530 return emitDefaultSilenceableFailure(target);
531 }
532
533 FailureOr<SmallVector<Value>> maybeNewResults =
534 decomposableOp.decomposeOperation(rewriter);
535 if (failed(maybeNewResults))
536 return emitDefaultSilenceableFailure(target);
537
538 rewriter.replaceOp(decomposableOp, *maybeNewResults);
539 for (Value val : *maybeNewResults) {
540 Operation *definition = val.getDefiningOp();
541 if (definition)
542 results.push_back(definition);
543 }
544 return DiagnosedSilenceableFailure::success();
545}
546
547//===----------------------------------------------------------------------===//
548// EliminateLinalgOpAnchoredEmptyTensorsOp
549//===----------------------------------------------------------------------===//
550
551void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
552 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
553 onlyReadsHandle(getTargetMutable(), effects);
554 modifiesPayload(effects);
555}
556
557DiagnosedSilenceableFailure
558transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
559 transform::TransformRewriter &rewriter, TransformResults &transformResults,
560 TransformState &state) {
561 bufferization::OneShotBufferizationOptions options;
562 options.allowReturnAllocsFromLoops = true;
563
564 for (Operation *target : state.getPayloadOps(getTarget())) {
565 bufferization::OneShotAnalysisState state(target, options);
566 if (failed(analyzeOp(target, state)))
567 return mlir::emitSilenceableFailure(target->getLoc())
568 << "failed to analyze op";
569 if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
570 rewriter, target, state)))
571 return mlir::emitSilenceableFailure(target->getLoc())
572 << "failed to eliminate LinalgOp anchored tensor.empty ops";
573 }
574 return DiagnosedSilenceableFailure::success();
575}
576
577//===----------------------------------------------------------------------===//
578// FuseOp
579//===----------------------------------------------------------------------===//
580
581void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
582 TypeRange loopTypes, Value target,
583 ArrayRef<int64_t> staticTileSizes,
584 ArrayRef<int64_t> staticTileInterchange,
585 bool applyCleanup, bool useForall) {
586 return build(
587 builder, result, loopTypes,
588 /*target=*/target,
589 /*mixedTileSizes=*/
590 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
591 /*mixedTileInterchange=*/
592 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
593 applyCleanup, useForall);
594}
595
596void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
597 Value target, ArrayRef<int64_t> staticTileSizes,
598 ArrayRef<int64_t> staticTileInterchange,
599 bool applyCleanup, bool useForall) {
600 return build(
601 builder, result,
602 /*target=*/target,
603 /*mixedTileSizes=*/
604 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
605 /*mixedTileInterchange=*/
606 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileInterchange)),
607 applyCleanup, useForall);
608}
609
610void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
611 Value target,
612 ArrayRef<OpFoldResult> mixedTileSizes,
613 ArrayRef<OpFoldResult> mixedTileInterchange,
614 bool applyCleanup, bool useForall) {
615 // Loop types are automatically splat by the callee, setting up one is
616 // enough.
617 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
618 build(builder, result, loopTypes, target, mixedTileSizes,
619 mixedTileInterchange, applyCleanup, useForall);
620}
621
622void transform::FuseOp::build(OpBuilder &builder, OperationState &result,
623 TypeRange loopTypes, Value target,
624 ArrayRef<OpFoldResult> mixedTileSizes,
625 ArrayRef<OpFoldResult> mixedTileInterchange,
626 bool applyCleanup, bool useForall) {
627 SmallVector<int64_t> staticTileSizes;
628 SmallVector<Value> dynamicTileSizes;
629 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
630 SmallVector<int64_t> staticTileInterchange;
631 SmallVector<Value> dynamicTileInterchange;
632 dispatchIndexOpFoldResults(mixedTileInterchange, dynamicTileInterchange,
633 staticTileInterchange);
634 // Call the default builder which sets up the proper operands segment sizes
635 // attributes for multiple variadic operands. In the absence of this,
636 // horrible bugs ensue.
637 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
638 auto staticTileInterchangeAttr =
639 builder.getDenseI64ArrayAttr(staticTileInterchange);
640 unsigned numExpectedLoops =
641 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
642 SmallVector<Type> resultTypes;
643 resultTypes.reserve(numExpectedLoops);
644 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
645 "expected one loop type or as many as loops");
646 if (loopTypes.size() == 1)
647 resultTypes.append(numExpectedLoops, loopTypes[0]);
648 else
649 llvm::append_range(resultTypes, loopTypes);
650 build(builder, result, /*transformed=*/target.getType(),
651 /*loops=*/resultTypes,
652 /*target=*/target,
653 /*tile_sizes=*/dynamicTileSizes,
654 /*tile_interchange=*/dynamicTileInterchange,
655 /*static_tile_sizes=*/staticTileSizesAttr,
656 /*static_tile_interchange=*/staticTileInterchangeAttr,
657 /*apply_cleanup=*/applyCleanup,
658 /*use_forall=*/useForall);
659}
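// Illustrative builder usage (sketch, not from the upstream file; `b` and
// `opState` are assumed to be an OpBuilder and OperationState, `target` a
// transform handle Value): tile sizes {32, 0, 16}, no interchange, cleanup
// enabled, scf.for loops. With useForall=false this declares one loop result
// per non-zero tile size, i.e. two here.
//
//   transform::FuseOp::build(b, opState, target,
//                            /*staticTileSizes=*/{32, 0, 16},
//                            /*staticTileInterchange=*/{},
//                            /*applyCleanup=*/true, /*useForall=*/false);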
660
661/// Apply a tiling transformation to all payload ops and store both the
662/// tiled operation as well as the created tile loops.
663template <typename Range>
664static LogicalResult applyTilingToAll(
665 RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
666 unsigned numLoops, transform::TransformResults &transformResults,
667 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
668 applyFn) {
669 SmallVector<Operation *> tiledLinalgOps;
670 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
671
672 for (Operation *target : payloadOps) {
673 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
674 if (!tilingInterfaceOp)
675 return transformOp->emitError("only TilingInterface ops are supported");
676
677 rewriter.setInsertionPoint(target);
678 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
679 applyFn(tilingInterfaceOp);
680 if (failed(tiledResults))
681 return failure();
682
683 // Perform the replacement of tiled and fused values.
684 SmallVector<Operation *> opsToReplace{target};
685 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
686 for (Operation *toReplace : opsToReplace) {
687 for (OpResult res : toReplace->getResults())
688 if (auto replacement = tiledResults->replacements.lookup(res))
689 rewriter.replaceAllUsesWith(res, replacement);
690 if (toReplace->use_empty()) {
691 rewriter.eraseOp(toReplace);
692 }
693 }
694
695 // Report back the relevant handles to the transform op.
696 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
697 assert(tiledResults->loops.size() == numLoops &&
698 "Mismatched number of loops, tile and fuse transform should have "
699 "failed");
700 for (unsigned int i = 0; i < numLoops; ++i)
701 loopOps[i].push_back(tiledResults->loops[i]);
702 }
703
704 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
705 for (unsigned int i = 0; i < numLoops; ++i)
706 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
707
708 return success();
709}
710
711DiagnosedSilenceableFailure
712transform::FuseOp::apply(transform::TransformRewriter &rewriter,
713 mlir::transform::TransformResults &transformResults,
714 mlir::transform::TransformState &state) {
715 auto transformOp = cast<TransformOpInterface>(getOperation());
716
717 SmallVector<int64_t> tileSizes;
718 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
719 state, transformOp, getMixedTileSizes(), tileSizes);
720 if (!status.succeeded())
721 return status;
722 SmallVector<int64_t> tileInterchange;
723 status = reifyMixedParamAndHandleResults(
724 state, transformOp, getMixedTileInterchange(), tileInterchange);
725 if (!status.succeeded())
726 return status;
727
728 scf::SCFTilingOptions tilingOptions;
729 tilingOptions.interchangeVector = tileInterchange;
730 bool useForall = getUseForall();
731 tilingOptions.setLoopType(useForall
732 ? scf::SCFTilingOptions::LoopType::ForallOp
733 : scf::SCFTilingOptions::LoopType::ForOp);
734 SmallVector<OpFoldResult> tileSizesOfr =
735 getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
736 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
737 scf::SCFTileAndFuseOptions tileAndFuseOptions;
738 tileAndFuseOptions.tilingOptions = tilingOptions;
739
740 if (getApplyCleanup()) {
741 MLIRContext *context = rewriter.getContext();
742 RewritePatternSet patterns(context);
743 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
746 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
747 }
748
749 size_t numLoops =
750 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
751 LogicalResult result = applyTilingToAll(
752 rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
753 transformResults,
754 [&](TilingInterface tilingInterfaceOp)
755 -> FailureOr<scf::SCFTileAndFuseResult> {
756 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
757 tileAndFuseOptions);
758 });
759 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
760 : DiagnosedSilenceableFailure::success();
761}
762
763LogicalResult transform::FuseOp::verify() {
764 auto iterspace_rank = getStaticTileSizes().size();
765 ArrayRef<int64_t> permutation = getStaticTileInterchange();
766 if (permutation.size() > iterspace_rank)
767 return emitOpError()
768 << "interchange length exceeds iteration space dimensions ("
769 << iterspace_rank << "), found " << getTileInterchange();
770 SmallVector<bool> seen(iterspace_rank, false);
771 for (int64_t v : permutation) {
772 if (!ShapedType::isDynamic(v)) {
773 if (v < 0 || v >= static_cast<int64_t>(iterspace_rank))
774 return emitOpError() << "expects interchange values to be in range [0, "
775 << iterspace_rank << "), found: " << v;
776 if (seen[v])
777 return emitOpError() << "found duplicate interchange value: " << v;
778 seen[v] = true;
779 }
780 }
781
782 ArrayRef<int64_t> sizes = getStaticTileSizes();
783 size_t numExpectedLoops =
784 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
785 if (numExpectedLoops != getNumResults() - 1)
786 return emitOpError() << "expects " << numExpectedLoops << " loop results";
787
788 return success();
789}
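// Worked example for the verifier rules above (editorial note, not from the
// upstream file): with static_tile_sizes = [32, 0, 16] and use_forall = false,
// the op must declare exactly two loop results (a zero size produces no loop).
// An interchange of [2, 0, 1] is accepted; [0, 0, 1] is rejected as a
// duplicate and any value >= 3 is rejected as out of range.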
790
791SmallVector<OpFoldResult> transform::FuseOp::getMixedTileSizes() {
792 return getMixedValues(getStaticTileSizes(), getTileSizes(), getContext());
793}
794
795SmallVector<OpFoldResult> transform::FuseOp::getMixedTileInterchange() {
796 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
797 getContext());
798}
799
800void transform::FuseOp::getEffects(
801 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
802 consumesHandle(getTargetMutable(), effects);
803 onlyReadsHandle(getTileSizesMutable(), effects);
804 onlyReadsHandle(getTileInterchangeMutable(), effects);
805 producesHandle(getOperation()->getOpResults(), effects);
806 modifiesPayload(effects);
807}
808
809//===----------------------------------------------------------------------===//
810// FuseIntoContainingOp
811//===----------------------------------------------------------------------===//
812
813void transform::FuseIntoContainingOp::build(OpBuilder &builder,
814 OperationState &result,
815 Value producerOp,
816 Value containingOp) {
817 result.addOperands({producerOp, containingOp});
818 auto resultType = transform::AnyOpType::get(builder.getContext());
819 result.addTypes({resultType, resultType});
820}
821
822/// Add new operands to the forall op for users of the producerOp
823/// that are dominated by the containing scf.forall op.
824static Operation *replaceForAllWithNewSignature(
825 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
826 Operation *containingOp, TilingResult &tileAndFuseResult,
827 int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
828 SmallVector<OpFoldResult> &sizes) {
829
830 // Count number of users not including the containing op
831 SetVector<Operation *> dominatedUsers;
832 DominanceInfo domInfo(containingOp);
833 for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
834 if (!containingOp->isAncestor(user) &&
835 (domInfo.dominates(containingOp, user))) {
836 dominatedUsers.insert(user);
837 }
838 }
839 if (dominatedUsers.empty())
840 return nullptr;
841
842 // Create new scf.forall op
843 auto forallOp = cast<scf::ForallOp>(containingOp);
844 OpBuilder::InsertionGuard g(rewriter);
845 rewriter.setInsertionPoint(forallOp);
846
847 // Get new output
848 Location loc = forallOp.getLoc();
849 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
850 if (!genericOp)
851 return nullptr;
852 SmallVector<Value> outputs = genericOp.getOutputs();
853 SmallVector<Value> newOuts(forallOp.getOutputs());
854 newOuts.push_back(outputs[resultNumber]);
855
856 // Create new scf.forall op
857 auto newforallOp = scf::ForallOp::create(
858 rewriter, loc, forallOp.getMixedLowerBound(),
859 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
860 forallOp.getMapping());
861 rewriter.eraseBlock(newforallOp.getBody());
862 newforallOp.getRegion().takeBody(forallOp.getRegion());
863
864 // Add additional block argument for new value being returned
865 // and replaces all uses of the new output with corresponding bbArg
866 // inside the scf.forall to enable fusion into this new scf.forall.
867 newforallOp.getBody()->addArgument(newOuts.back().getType(),
868 newOuts.back().getLoc());
869 auto bbArgs = newforallOp.getBody()->getArguments();
870 rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
871 [&](OpOperand &use) {
872 Operation *op = use.getOwner();
873 return newforallOp->isProperAncestor(op);
874 });
875
876 // Fix terminator
877 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
878 SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
879 terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
880 Operation *firstYieldOp = yieldingOps.front();
881 rewriter.setInsertionPoint(firstYieldOp);
882 Value src = tileAndFuseResult.tiledValues[0];
883 Value dst = newforallOp.getRegionIterArgs().back();
884 SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
885 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src,
886 dst, offsets, sizes, strides);
887
888 for (auto result : llvm::enumerate(forallOp.getResults())) {
889 rewriter.replaceAllUsesWith(result.value(),
890 newforallOp->getResult(result.index()));
891 }
892 rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
893 newforallOp->getResults().back(),
894 [&](OpOperand &use) {
895 Operation *user = use.getOwner();
896 return dominatedUsers.contains(user);
897 });
898 return newforallOp;
899}
900
901/// Given two operands coming from a loop iter arg, 'src' and 'dst', return
902/// true if the operand 'src' is equal to 'dst' or equal to an iter arg present
903/// in an outer loop. To determine the second condition, this function iterates
904/// using a worklist over the enclosing loops, trying to find 'src' in any of
905/// the parent loop's iter args.
906static bool sameOrEquivalentIterArg(Value src, Value dst) {
907 // Stack-like vector containing possible iterArg candidates. The first one
908 // is dst, and we will traverse the IR from there.
909 SmallVector<Value> destWorklist;
910 destWorklist.push_back(dst);
911
912 while (!destWorklist.empty()) {
913 Value currentDst = destWorklist.pop_back_val();
914
915 // We have found the same operand in some iter arg in the loop structure,
916 // so src and dst are equivalent.
917 if (src == currentDst)
918 return true;
919
920 // The operands are not equivalent, look for enclosing loops over
921 // currentDst.
922 auto bbArg = dyn_cast<BlockArgument>(currentDst);
923 if (!bbArg)
924 continue;
925
926 Block *parentBlock = bbArg.getOwner();
927 assert(parentBlock && "unlinked block argument");
928
929 Operation *parentOp = parentBlock->getParentOp();
930 assert(parentOp && "expected block argument with parent operation");
931
932 // Check if parent is loop-like. If it's not, do not add it to the worklist.
933 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
934 if (!parentLoop)
935 continue;
936
937 for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
938 // No need to check for null as innerIterArg is tied to parentLoop.
939 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
940 Value loopBlockArgument =
941 parentLoop->getOperand(operand->getOperandNumber());
942 destWorklist.push_back(loopBlockArgument);
943 }
944 }
945
946 return false;
947}
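// Illustrative example (sketch, not from the upstream file): for nested loops
//
//   %r0 = scf.for ... iter_args(%outer = %init) -> (tensor<?xf32>) {
//     %r1 = scf.for ... iter_args(%inner = %outer) -> (tensor<?xf32>) {
//       ...
//     }
//     scf.yield %r1 : tensor<?xf32>
//   }
//
// sameOrEquivalentIterArg(%init, %inner) returns true: walking up from %inner,
// its tied init is %outer, whose tied init is %init.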
948
949/// Find the first "extract" user of `producerOp` and tile it right before its
950/// use. The tiled op is fused under the `containingOp`.
951/// Return this fused op on success or nullptr if anything fails.
952/// If tiled op has uses that are dominated by `containingOp`, return
953/// a new `containingOp` with results of the fused op appended to
954/// results of the `containingOp` or nullptr if there are no dominated uses.
955static std::tuple<SmallVector<Operation *>, Operation *>
956tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
957 Operation *producerOp, Operation *containingOp) {
958 LDBG() << "Try to fuse a direct extract use";
959 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
960 if (!tileableProducer) {
961 diag.attachNote(producerOp->getLoc())
962 << "producer is not a TileableInterface: " << *producerOp;
963 return {};
964 }
965
966 // Search the producer slices accessed within the containing operation.
967 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
968 // evolve into an interface.
969 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
970 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
971 return sliceOp && containingOp->isProperAncestor(sliceOp);
972 });
973
974 // Find a fusion opportunity.
975 if (it == tileableProducer->getUsers().end()) {
976 diag.attachNote(tileableProducer->getLoc())
977 << "could not find fusion opportunity for: " << *tileableProducer;
978 return {};
979 }
980 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
981
982 // Try to fuse the producer in-place.
983 OpBuilder::InsertionGuard guard(rewriter);
984 rewriter.setInsertionPoint(sliceOpToTile);
985
986 // Clone the producer inside the consumer and try to update the producer init
987 // operands using the loop bbArgs if applicable. More precisely, if a bbArg
988 // of the container loop points to a value that is used by the consumer op,
989 // then, instead of using that value in the consumer, use the value coming
990 // from the bbArg instead. This allows reusing the output tensor of the
991 // container (instead of creating a new one) when both producer and container
992 // write to the same output.
993 if (LoopLikeOpInterface containerLoop =
994 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
995 Operation *clone = rewriter.clone(*producerOp);
996 rewriter.modifyOpInPlace(clone, [&]() {
997 // Iterate over the outputs of the producer and over the loop bbArgs and
998 // check if any bbArg points to the same value as the producer output. In
999 // such case, make the producer output point to the bbArg directly.
1000 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
1001 if (!dpsInterface)
1002 return;
1003
1004 for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1005 Value producerOperand =
1006 clone->getOperand(initOperandPtr.getOperandNumber());
1007 for (BlockArgument containerIterArg :
1008 containerLoop.getRegionIterArgs()) {
1009 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1010 Value consumerOperand =
1011 containerLoop->getOperand(bbArg->getOperandNumber());
1012 // The producer has the same init as the loop bbArg, use it.
1013 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
1014 initOperandPtr.set(containerIterArg);
1015 }
1016 }
1017 }
1018 });
1019
1020 tileableProducer = dyn_cast<TilingInterface>(clone);
1021 }
1022
1023 // Tile the producer.
1024 int64_t resultNumber =
1025 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1026 LDBG() << "resultNumber: " << resultNumber;
1027
1028 SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
1029 SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
1030
1031 FailureOr<TilingResult> tileAndFuseResult =
1032 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1033 sizes);
1034
1035 if (failed(tileAndFuseResult)) {
1036 diag.attachNote(tileableProducer->getLoc())
1037 << "failed to tile producer op: " << *tileableProducer;
1038 return {};
1039 }
1040
1041#ifndef NDEBUG
1042 for (auto *tiledOp : tileAndFuseResult->tiledOps) {
1043 LDBG() << "tiledProducer: " << *tiledOp;
1044 }
1045#endif
1046
1047 // Replace the extract op.
1048 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1049 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1050 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1051 if (failed(maybeRankReduced)) {
1052 diag.attachNote(producerOp->getLoc())
1053 << "shape types don't match (missing canonicalization?):\nTiledOp: "
1054 << tileAndFuseResult->tiledValues[0]
1055 << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
1056 return {};
1057 }
1058 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1059
1060 // Add new outputs to containing op, if required
1061 Operation *newContainingOp = replaceForAllWithNewSignature(
1062 rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
1063 resultNumber, offsets, sizes);
1064
1065 // Cleanup clone.
1066 if (isa<LoopLikeOpInterface>(containingOp))
1067 rewriter.eraseOp(tileableProducer);
1068
1069 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1070}
1071
1072/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
1073/// it is exactly the `containingOp`, otherwise bail.
1074/// Then, find the first "extract" user of the tied block argument and tile it
1075/// right before its "extract" use. The tiled op is fused under the
1076/// `containingOp`.
1077/// Return this fused op on success or nullptr if anything fails.
1078static SmallVector<Operation *>
1079tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1080 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
1081 Operation *containingOp) {
1082 LDBG() << "Try to fuse an extract use through block argument";
1083
1084 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1085 if (!tileableProducer) {
1086 diag.attachNote(producerOp->getLoc())
1087 << "producer is not a TileableInterface: " << *producerOp;
1088 return {};
1089 }
1090
1091 // Search the first use by a "scf::ForallOp" user.
1092 scf::ForallOp forallOp;
1093 auto itProducerUses =
1094 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
1095 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1096 return forallOp;
1097 });
1098 // If it's not from the containing op, return.
1099 if (!forallOp || forallOp != containingOp) {
1100 diag.attachNote(tileableProducer->getLoc())
1101 << "could not find a use by the containing op: " << *tileableProducer;
1102 return {};
1103 }
1104
1105 // Search the producer slices accessed within the containing
1106 // operation.
1107 // TODO: Generalize to more extract/insert/parallel_insert triples.
1108 // Maybe evolve into an interface.
1109 OpOperand *pUse = &(*itProducerUses);
1110 BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
1111
1112 // Search the producer slices accessed within the containing operation.
1113 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
1114 // evolve into an interface.
1115 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
1116 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1117 return sliceOp && containingOp->isProperAncestor(sliceOp);
1118 });
1119
1120 // Find a fusion opportunity.
1121 if (itBBArgUsers == bbArg.getUsers().end()) {
1122 diag.attachNote(containingOp->getLoc())
1123 << "could not find fusion opportunity for bbArg: " << bbArg;
1124 return {};
1125 }
1126 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1127
1128 // Try to fuse the producer in-place.
1129 OpBuilder::InsertionGuard guard(rewriter);
1130 rewriter.setInsertionPoint(sliceOpToTile);
1131
1132 // Replace the use in the tileableProducer before tiling: clone, replace and
1133 // then tile.
1134 int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
1135 LDBG() << "resultNumber: " << resultNumber;
1136
1137 // Gather destination tensors.
1138 SmallVector<Value> destinationTensors;
1139 if (failed(tensor::getOrCreateDestinations(
1140 rewriter, tileableProducer->getLoc(), tileableProducer,
1141 destinationTensors))) {
1142 diag.attachNote(tileableProducer->getLoc())
1143 << "failed to get destination tensors for: " << *tileableProducer;
1144 return {};
1145 }
1146
1147 IRMapping bvm;
1148 bvm.map(destinationTensors[resultNumber], bbArg);
1149 auto tileableProducerClone =
1150 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
1151 auto scopeGuard =
1152 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
1153
1154 // Tile the producer.
1155 FailureOr<TilingResult> tileAndFuseResult =
1156 tileableProducerClone.generateResultTileValue(
1157 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1158 sliceOpToTile.getMixedSizes());
1159 if (failed(tileAndFuseResult)) {
1160 diag.attachNote(tileableProducer->getLoc())
1161 << "failed to tile producer op: " << *tileableProducer;
1162 return {};
1163 }
1164
1165 // Replace the extract op.
1166 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1167 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1168 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1169 assert(succeeded(maybeRankReduced) && "unexpected shape");
1170 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
1171
1172 // Replace the use in containingOp.
1173 rewriter.modifyOpInPlace(containingOp, [&]() {
1174 containingOp->setOperand(pUse->getOperandNumber(),
1175 destinationTensors.front());
1176 });
1177
1178 return tileAndFuseResult->tiledOps;
1179}
1180
1181static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
1182 Operation *producerOp,
1183 Operation *containingOp) {
1184 LDBG() << "Try to fuse an use by cloning";
1185
1186 // Gather all uses inside the containing op.
1187 SmallVector<OpOperand *> uses;
1188 for (OpResult result : producerOp->getOpResults()) {
1189 for (OpOperand &use : result.getUses()) {
1190 if (containingOp->isProperAncestor(use.getOwner())) {
1191 uses.push_back(&use);
1192 continue;
1193 }
1194 // Cannot clone and fuse if the use is by the containing op itself: fail
1195 // immediately.
1196 if (containingOp == use.getOwner()) {
1197 diag.attachNote(producerOp->getLoc())
1198 << "producer op use by containing op cannot be fused by cloning";
1199 return nullptr;
1200 }
1201 }
1202 }
1203
1204 // Check for a non-empty list of fusion opportunities.
1205 if (uses.empty()) {
1206 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1207 return nullptr;
1208 }
1209
1210 // Clone and fuse inside the containing op.
1211 Operation *fusedOp = nullptr;
1212 OpOperand *use = uses.front();
1213 // Parallel insert slice is not a valid clone destination.
1214 // TODO: Generalize to other type of ops.
1215 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1216 "Parallel insert slice is not a valid clone destination");
1217 unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1218 LDBG() << "resultNumber: " << resultNumber;
1219
1220 OpBuilder::InsertionGuard guard(rewriter);
1221 rewriter.setInsertionPoint(use->getOwner());
1222 fusedOp = rewriter.clone(*producerOp);
1223 rewriter.modifyOpInPlace(
1224 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1225
1226 return fusedOp;
1227}
1228
1229bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1230 // Allow repeated handles since we are fusing everything anyway.
1231 return true;
1232}
1233
1234DiagnosedSilenceableFailure
1235transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1236 transform::TransformResults &results,
1237 transform::TransformState &state) {
1238 SmallVector<Operation *> fusedOps;
1239 auto producerOps = state.getPayloadOps(getProducerOp());
1240 auto containingOps = state.getPayloadOps(getContainingOp());
1241 if (!llvm::hasSingleElement(containingOps)) {
1242 return emitDefiniteFailure()
1243 << "requires exactly one containing_op handle (got "
1244 << llvm::range_size(containingOps) << ")";
1245 }
1246 Operation *containingOp = *containingOps.begin();
1247
1248 // If nothing to fuse, propagate success.
1249 if (std::empty(producerOps)) {
1250 results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1251 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1252 return DiagnosedSilenceableFailure::success();
1253 }
1254
1255 // Helper function to find the next producer that should be fused. Take any
1256 // producer that has a use inside the containing op.
1257 SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1258 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1259 for (const auto &it : enumerate(remainingProducers)) {
1260 Operation *producerOp = it.value();
1261 // The containing op may be a user of producerOp: use isAncestor.
1262 int64_t numUsesInContainingOp =
1263 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1264 return containingOp->isAncestor(op);
1265 });
1266 // TODO: When resolving the TODO below (no duplicate ops), take an op
1267 // that has no use among the remaining producers. This is a topological
1268 // sorting.
1269 if (numUsesInContainingOp > 0) {
1270 if (numUsesInContainingOp == 1)
1271 remainingProducers.erase(remainingProducers.begin() + it.index());
1272 return producerOp;
1273 }
1274 }
1275 return failure();
1276 };
1277
1278 while (!remainingProducers.empty()) {
1279 auto nextProducer = getNextProducer();
1280 if (failed(nextProducer)) {
1281 auto diag = mlir::emitSilenceableFailure(getLoc())
1282 << "could not find next producer to fuse into container";
1283 diag.attachNote(containingOp->getLoc()) << "containing op";
1284 return diag;
1285 }
1286
1287 Operation *producerOp = *nextProducer;
1288
1289 // Default diagnostic, to be complemented with more failure information.
1291 diag << "could not fuse " << *producerOp << " into " << *containingOp;
1292
1293 // TODO: If there are multiple uses of the producer in the containing op,
1294 // we currently tile/clone the op multiple times (once per use). In some
1295 // cases, we can tile/clone once and reuse the value for each use.
1296 // Furthermore, producers should then be traversed according to a
1297 // topological sorting.
1298 auto [tiledOps, newContainingOp] =
1299 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1300 if (!tiledOps.empty()) {
1301 LDBG() << "\nFused a direct extract use\n" << *containingOp;
1302 fusedOps.append(tiledOps);
1303 if (newContainingOp) {
1304 // Update handles associated with the containing op so we don't need to
1305 // invalidate them. This is a hack to support better composability
1306 // between tiling and fusion while a proper mechanism is being
1307 // investigated.
1308 //
1309 // DO NOT replicate this elsewhere unless you understand what you are
1310 // doing.
1311 LogicalResult replacementStatus =
1312 rewriter.notifyPayloadOperationReplaced(containingOp,
1313 newContainingOp);
1314 (void)replacementStatus;
1315 assert(succeeded(replacementStatus) &&
1316 "unable to update transform state mapping");
1317 rewriter.eraseOp(containingOp);
1318 containingOp = newContainingOp;
1319 }
1320 continue;
1321 }
1322
1323 SmallVector<Operation *> tiledContainingOpOperand =
1324 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1325 rewriter, diag, producerOp, containingOp);
1326 if (!tiledContainingOpOperand.empty()) {
1327 LDBG() << "\nFused an extract use through block argument\n"
1328 << *containingOp;
1329 fusedOps.append(tiledContainingOpOperand);
1330 continue;
1331 }
1332
1333 Operation *cloned =
1334 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1335 if (cloned) {
1336 LDBG() << "\nFused an use by cloning\n" << *containingOp;
1337 fusedOps.push_back(cloned);
1338 continue;
1339 }
1340 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1341 }
1342
1343 results.set(cast<OpResult>(getFusedOp()), fusedOps);
1344 results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1345 return DiagnosedSilenceableFailure::success();
1346}
1347
1348void transform::FuseIntoContainingOp::getEffects(
1349 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1350 consumesHandle(getProducerOpMutable(), effects);
1351 onlyReadsHandle(getContainingOpMutable(), effects);
1352 producesHandle(getOperation()->getOpResults(), effects);
1353 modifiesPayload(effects);
1354}
1355
1356//===----------------------------------------------------------------------===//
1357// GeneralizeOp
1358//===----------------------------------------------------------------------===//
1359
1360DiagnosedSilenceableFailure
1361transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1362 LinalgOp target,
1363 transform::ApplyToEachResultList &results,
1364 transform::TransformState &state) {
1365 // Exit early if no transformation is needed.
1366 if (isa<GenericOp>(target)) {
1367 results.push_back(target);
1368 return DiagnosedSilenceableFailure::success();
1369 }
1370 rewriter.setInsertionPoint(target);
1371 FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1372 if (succeeded(generic)) {
1373 results.push_back(generic->getOperation());
1374 return DiagnosedSilenceableFailure::success();
1375 }
1376 return emitDefaultSilenceableFailure(target);
1377}
1378
1379//===----------------------------------------------------------------------===//
1380// SpecializeOp
1381//===----------------------------------------------------------------------===//
1382
1383DiagnosedSilenceableFailure
1384transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1385 LinalgOp target,
1386 transform::ApplyToEachResultList &results,
1387 transform::TransformState &state) {
1388 // Exit early if the operation is not a generic.
1389 if (!isa<GenericOp>(target)) {
1390 results.push_back(target);
1391 return DiagnosedSilenceableFailure::success();
1392 }
1393 rewriter.setInsertionPoint(target);
1394 FailureOr<LinalgOp> named =
1395 specializeGenericOp(rewriter, cast<GenericOp>(target));
1396 if (succeeded(named)) {
1397 results.push_back(named->getOperation());
1398 return DiagnosedSilenceableFailure::success();
1399 }
1400 return emitDefaultSilenceableFailure(target);
1401}
1402
1403//===----------------------------------------------------------------------===//
1404// InterchangeOp
1405//===----------------------------------------------------------------------===//
1406
1407DiagnosedSilenceableFailure
1408transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1409 GenericOp target,
1410 transform::ApplyToEachResultList &results,
1411 transform::TransformState &state) {
1412 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1413 // Exit early if no transformation is needed.
1414 if (interchangeVector.empty()) {
1415 results.push_back(target);
1416 return DiagnosedSilenceableFailure::success();
1417 }
1418
1419 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1420 if (interchangeVector.size() != numLoops) {
1421 return emitSilenceableError()
1422 << getIteratorInterchangeAttrName() << " has length ("
1423 << interchangeVector.size()
1424 << ") different from the number of loops in the target operation ("
1425 << numLoops << ")";
1426 }
1427 FailureOr<GenericOp> res = interchangeGenericOp(
1428 rewriter, target, SmallVector<unsigned>(interchangeVector));
1429 if (failed(res))
1430 return emitDefiniteFailure() << "failed to apply";
1431 results.push_back(res->getOperation());
1432 return DiagnosedSilenceableFailure::success();
1433}
1434
1435LogicalResult transform::InterchangeOp::verify() {
1436 ArrayRef<int64_t> permutation = getIteratorInterchange();
1437 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1438 if (!std::is_permutation(sequence.begin(), sequence.end(),
1439 permutation.begin(), permutation.end())) {
1440 return emitOpError()
1441 << "expects iterator_interchange to be a permutation, found "
1442 << getIteratorInterchange();
1443 }
1444 return success();
1445}
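// Worked example (editorial note, not from the upstream file): for a
// linalg.generic with three loops, iterator_interchange = [1, 0, 2] verifies
// (it is a permutation of 0..2); [0, 0, 2] and [1, 2] are rejected because
// they are not permutations of the full loop sequence.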
1446
1447//===----------------------------------------------------------------------===//
1448// LinalgCopyToMemrefOp
1449//===----------------------------------------------------------------------===//
1450
1451DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1452 transform::TransformRewriter &rewriter, Operation *targetOp,
1453 transform::ApplyToEachResultList &results,
1454 transform::TransformState &state) {
1455
1456 // Check if the target can be converted.
1457 if (!isa<linalg::CopyOp>(targetOp)) {
1458 DiagnosedSilenceableFailure diag =
1459 emitSilenceableError() << "only linalg.copy target ops are supported";
1460 diag.attachNote(targetOp->getLoc()) << "target op";
1461 return diag;
1462 }
1463
1464 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1465 if (!copyOp.hasPureBufferSemantics()) {
1466 DiagnosedSilenceableFailure diag =
1467 emitSilenceableError()
1468 << "cannot transform a linalg.copy on tensors into a memref.copy";
1469 diag.attachNote(targetOp->getLoc()) << "target op";
1470 return diag;
1471 }
1472
1473 SmallVector<Value> inputs = copyOp.getInputs();
1474 SmallVector<Value> outputs = copyOp.getOutputs();
1475 assert(inputs.size() == 1 && "expected linalg copy op with one input");
1476 assert(outputs.size() == 1 && "expected memref copy op with one output");
1477 Value input = inputs.front();
1478 Value output = outputs.front();
1479
1480 // linalg.copy supports different element types on source/dest whereas
1481 // memref.copy does not, so we must check that the source and dest types can
1482 // be handled by memref.copy and otherwise reject the transformation.
1483 if (!isa<ShapedType>(input.getType())) {
1484 DiagnosedSilenceableFailure diag =
1485 emitSilenceableError()
1486 << "cannot transform a linalg.copy whose input has no shape";
1487 diag.attachNote(targetOp->getLoc()) << "target op";
1488 return diag;
1489 }
1490
1491 // linalg.copy destination must be a shaped type.
1492 assert(isa<ShapedType>(output.getType()));
1493
1494 if (cast<ShapedType>(input.getType()).getElementType() !=
1495 cast<ShapedType>(output.getType()).getElementType()) {
1496 DiagnosedSilenceableFailure diag =
1497 emitSilenceableError()
1498 << "cannot transform a linalg.copy with different source and "
1499 "destination element types ";
1500 diag.attachNote(targetOp->getLoc()) << "target op";
1501 return diag;
1502 }
1503
1504 // Target can be converted, do it.
1505 auto memrefCopyOp =
1506 rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1507
1508 results.push_back(memrefCopyOp);
1509 return DiagnosedSilenceableFailure::success();
1510}
1511
1512//===----------------------------------------------------------------------===//
1513// LowerPackOp
1514//===----------------------------------------------------------------------===//
1515
1516DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1517 transform::TransformRewriter &rewriter, linalg::PackOp target,
1518 transform::ApplyToEachResultList &transformResults,
1519 transform::TransformState &state) {
1520 rewriter.setInsertionPoint(target);
1521 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1522 FailureOr<LowerPackResult> res =
1523 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1524 if (failed(res)) {
1525 return mlir::emitSilenceableFailure(target->getLoc())
1526 << "cannot lower to pad + expand + transpose";
1527 }
1528 transformResults.push_back(res->padOp);
1529 transformResults.push_back(res->expandShapeOp);
1530 transformResults.push_back(res->transposeOp);
1531 return DiagnosedSilenceableFailure::success();
1532}
1533
1534//===----------------------------------------------------------------------===//
1535// LowerUnPackOp
1536//===----------------------------------------------------------------------===//
1537
1538DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1539 transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1540 transform::ApplyToEachResultList &transformResults,
1541 transform::TransformState &state) {
1542 rewriter.setInsertionPoint(target);
1543 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1544 FailureOr<LowerUnPackOpResult> res =
1545 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1546 if (failed(res)) {
1547 DiagnosedSilenceableFailure diag =
1548 emitSilenceableError()
1549 << "cannot lower to transpose + collapse + extract";
1550 diag.attachNote(target->getLoc()) << "target payload op";
1551 return diag;
1552 }
1553 transformResults.push_back(res->emptyOp);
1554 transformResults.push_back(res->transposeOp);
1555 transformResults.push_back(res->collapseShapeOp);
1556 transformResults.push_back(res->extractSliceOp);
1557 return DiagnosedSilenceableFailure::success();
1558}
1559
1560//===---------------------------------------------------------------------===//
1561// MatchOp
1562//===---------------------------------------------------------------------===//
1563
1564void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1565 Value target, ArrayRef<StringRef> opNames) {
1566 result.addOperands(target);
1567 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1568 builder.getStrArrayAttr(opNames));
1569 result.addTypes(transform::AnyOpType::get(builder.getContext()));
1570}
1571
1572void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1573 TypeRange resultTypes, Value target,
1574 ArrayRef<StringRef> opNames) {
1575 result.addOperands(target);
1576 result.addAttribute(MatchOp::getOpsAttrName(result.name),
1577 builder.getStrArrayAttr(opNames));
1578 result.addTypes(resultTypes);
1579}
1580
1581DiagnosedSilenceableFailure
1582transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1583 transform::TransformResults &results,
1584 transform::TransformState &state) {
1585 llvm::StringSet<> strs;
1586 if (getOps().has_value())
1587 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1588
1589 auto payloadOps = state.getPayloadOps(getTarget());
1590 if (!llvm::hasSingleElement(payloadOps)) {
1591 return emitDefiniteFailure("requires exactly one target handle");
1592 }
1593
1594 SmallVector<Operation *> res;
1595 bool incorrectNumOperandTypes = false;
1596 auto matchFun = [&](Operation *op) {
1597 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1598 return;
1599
1600 // Interfaces cannot be matched by name, just by ID.
1601 // So we specifically encode the interfaces we care about for this op.
1602 if (getInterface().has_value()) {
1603 auto iface = getInterface().value();
1604 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1605 !isa<LinalgOp>(op))
1606 return;
1607 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1608 !isa<TilingInterface>(op))
1609 return;
1610 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1611 !isa<LoopLikeOpInterface>(op))
1612 return;
1613 }
1614
1615 // Check if all specified attributes match.
1616 if (getOpAttrs().has_value()) {
1617 DictionaryAttr opAttrs = getOpAttrs().value();
1618 for (NamedAttribute attr : opAttrs) {
1619 if (attr.getName() == getInterfaceAttrName() ||
1620 attr.getName() == getOpsAttrName())
1621 continue;
1622 if (!op->hasAttr(attr.getName()))
1623 return;
1624 if (op->getAttr(attr.getName()) != attr.getValue())
1625 return;
1626 }
1627 }
1628
1629 if (getFilterResultType().has_value()) {
1630 Type t = getFilterResultType().value();
1631 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1632 return;
1633 }
1634
1635 if (getFilterOperandTypes().has_value()) {
1636 mlir::ArrayAttr types = getFilterOperandTypes().value();
1637 auto operandTypes = op->getOperandTypes();
1638
1639 if (types.size() == 1) {
1640 // All the operands must be equal to the specified type.
1641 auto typeattr =
1642 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1643 Type t = cast<::mlir::Type>(typeattr.getValue());
1644 if (!llvm::all_of(op->getOperandTypes(),
1645 [&](Type operandType) { return operandType == t; }))
1646 return;
1647 } else {
1648 // The operand types must match all the types in the list (in the same
1649 // order in which they are specified)
1650 if (types.size() != operandTypes.size()) {
1651 incorrectNumOperandTypes = true;
1652 return;
1653 }
1654
1655 for (auto [attr, operandType] :
1656 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1657 auto typeattr = cast<mlir::TypeAttr>(attr);
1658 Type type = cast<::mlir::Type>(typeattr.getValue());
1659
1660 if (type != operandType)
1661 return;
1662 }
1663 }
1664 }
1665
1666 // All constraints are satisfied.
1667 res.push_back(op);
1668 return;
1669 };
1670
1671 (*payloadOps.begin())->walk(matchFun);
1672 if (incorrectNumOperandTypes)
1673 return emitDefiniteFailure("If filter_operand_types contains more than one "
1674 "type, then it must contain as many types as "
1675 "the number of operands in the target ops");
1676 results.set(cast<OpResult>(getResult()), res);
1677 return DiagnosedSilenceableFailure::success();
1678}
1679
1680//===---------------------------------------------------------------------===//
1681// MultiTileSizesOp
1682//===---------------------------------------------------------------------===//
1683
1684static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1685 Type targetType, Type lowSizeType, Type,
1686 Type) {
1687 printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1688}
1689
1690static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1691 Type &targetType, Type &lowSizeType,
1692 Type &highSizeType,
1693 Type &splitPointType) {
1694 FunctionType funcType;
1695 llvm::SMLoc typeLoc = parser.getCurrentLocation();
1696 if (failed(parser.parseType<FunctionType>(funcType)))
1697 return failure();
1698
1699 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1700 parser.emitError(typeLoc) << "expects a trailing functional type with one "
1701 "argument and one result";
1702 }
1703 targetType = funcType.getInput(0);
1704 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1705
1706 return success();
1707}
1708
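// Note: multi-size tiling computes a "low" and a "high" tile size for the
// requested dimension such that (roughly) both are multiples of the divisor,
// both are close to target_size, and low_size * low_trip_count plus the
// high-size portion covers the whole dimension. The split point result below
// is low_size * low_trip_count, i.e. where the loop switches from the low to
// the high tile size.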
1709DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1710 transform::TransformRewriter &rewriter, LinalgOp target,
1711 transform::ApplyToEachResultList &results, transform::TransformState &state) {
1712 if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1713 if (target.hasDynamicShape()) {
1714 auto diag = emitSilenceableError()
1715 << "cannot compute parametric tile sizes for dynamically "
1716 "shaped payload op";
1717 diag.attachNote(target->getLoc()) << "payload op";
1718 return diag;
1719 }
1720
1721 FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1722 target, getDimension(), getTargetSize(), getDivisor());
1723 if (failed(spec)) {
1724 return emitSilenceableError()
1725 << "failed to compute multi-size tiling sizes";
1726 }
1727
1728 Builder builder(target.getContext());
1729 results.assign(llvm::map_range(
1730 ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1731 spec->lowTileSize * spec->lowTripCount}),
1732 [&builder, this](int64_t value) {
1733 return builder.getIntegerAttr(
1734 cast<ParamType>(getLowSize().getType()).getType(), value);
1735 }));
1736 return DiagnosedSilenceableFailure::success();
1737 }
1738
1739 OpBuilder builder(target.getContext());
1740 builder.setInsertionPoint(target);
1741 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1742 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1743 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1744 builder, target, getDimension(), targetSize, divisor);
1745 if (failed(spec)) {
1746 return emitSilenceableError() << "could not generate tile size computation";
1747 }
1748
1749 AffineExpr s0 = builder.getAffineSymbolExpr(0);
1750 AffineExpr s1 = builder.getAffineSymbolExpr(1);
1751 Operation *splitPoint =
1752 affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1753 {spec->lowTileSize, spec->lowTripCount});
1754 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1755 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1756 assert(lowTileSize && highTileSize && splitPoint &&
1757 "tile sizes are not produced by operations");
1758 results.reserve(results.size() + 3);
1759 results.push_back(lowTileSize);
1760 results.push_back(highTileSize);
1761 results.push_back(splitPoint);
1762 return DiagnosedSilenceableFailure::success();
1763}
1764
1765void transform::MultiTileSizesOp::getEffects(
1766 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1767 onlyReadsHandle(getTargetMutable(), effects);
1768 producesHandle(getOperation()->getOpResults(), effects);
1769 if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1770 onlyReadsPayload(effects);
1771 else
1772 modifiesPayload(effects);
1773}
1774
1775LogicalResult transform::MultiTileSizesOp::verify() {
1776 if (getLowSize().getType() != getHighSize().getType() ||
1777 getLowSize().getType() != getSplitPoint().getType()) {
1778 return emitOpError() << "expects all results type to be the same";
1779 }
1780 return success();
1781}
1782
1783//===---------------------------------------------------------------------===//
1784// PackOp
1785//===---------------------------------------------------------------------===//
1786
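// Illustrative use (a sketch; the exact assembly syntax is defined by the op's
// declaration, and %matmul is an invented handle name): packing all three
// loops of a matmul might look like
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">
// i.e. one packed size per loop of the target LinalgOp, as checked in apply()
// below.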
1787void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1788 Value target,
1789 ArrayRef<OpFoldResult> mixedPackedSizes) {
1790 SmallVector<int64_t> staticPackedSizes;
1791 SmallVector<Value> dynamicPackedSizes;
1792 dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1793 staticPackedSizes);
1794 // Call the default builder which sets up the proper operands segment sizes
1795 // attributes for multiple variadic operands. In the absence of this, horrible
1796 // bugs ensue.
1797 Type linalgOpHType = transform::OperationType::get(
1798 builder.getContext(), GenericOp::getOperationName());
1799 build(builder, result,
1800 /*resultType=*/linalgOpHType,
1801 /*target=*/target,
1802 /*dynamic_sizes=*/dynamicPackedSizes,
1803 /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1804}
1805
1806SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1807 Builder b(getContext());
1808 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1809}
1810
1811DiagnosedSilenceableFailure
1812transform::PackOp::apply(transform::TransformRewriter &rewriter,
1813 transform::TransformResults &transformResults,
1814 transform::TransformState &state) {
1815 auto targetOps = state.getPayloadOps(getTarget());
1816 // If nothing to pack, propagate success.
1817 if (std::empty(targetOps)) {
1818 transformResults.set(cast<OpResult>(getPackedOp()),
1819 {});
1820 return DiagnosedSilenceableFailure::success();
1821 }
1822 // Fail on multi-op handles.
1823 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1824 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1825 return emitSilenceableError()
1826 << "requires target to map to exactly 1 LinalgOp (got "
1827 << llvm::range_size(targetOps) << ")";
1828 }
1829 // Fail on mismatched number of pack sizes.
1830 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1831 return emitSilenceableError()
1832 << "requires number of packed sizes match the number of loops ("
1833 << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1834 << ")";
1835 }
1836
1837 // Unpack handles to constants or actual SSA index values.
1838 SmallVector<OpFoldResult> packedSizes;
1840 state, *this, packedSizes, getMixedPackedSizes());
1841
1842 rewriter.setInsertionPoint(linalgOp);
1843 FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1844 if (failed(maybeResult))
1845 return emitDefiniteFailure("data tiling failed");
1846
1847 transformResults.set(cast<OpResult>(getPackedOp()),
1848 {maybeResult->packedLinalgOp.getOperation()});
1849 return DiagnosedSilenceableFailure::success();
1850}
1851
1852void transform::PackOp::getEffects(
1853 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1854 transform::consumesHandle(getTargetMutable(), effects);
1855 transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1856 transform::producesHandle(getOperation()->getOpResults(), effects);
1857 transform::modifiesPayload(effects);
1858}
1859
1860//===---------------------------------------------------------------------===//
1861// PackGreedilyOp.
1862//===---------------------------------------------------------------------===//
1863
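// Note: "greedy" packing targets payload ops that behave like matmuls; for
// each such op it packs the (m, n, k) dimensions by getMixedMatmulPackedSizes()
// in the order given by matmul_inner_dims_order, optionally padding sizes up to
// the requested next multiple. Failing to pack a particular payload op is not
// an error: the original op is simply returned in the result handle.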
1864LogicalResult transform::PackGreedilyOp::verify() {
1865 if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1866 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1867 << " is not a valid permutation";
1868 }
1869 // TODO: relax to allow empty once we have another strategy than just matmul.
1870 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1871 for (auto [s, nmo] :
1872 llvm::zip_equal(getMixedMatmulPackedSizes(),
1873 getMatmulPaddedSizesNextMultipleOf())) {
1874 std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1875 if (nmo != 0 &&
1876 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1877 return emitOpError() << "at most one of the packed_size and the "
1878 "padded_sizes_next_multiple_of can be nonzero "
1879 "for the matmul strategy";
1880 }
1881 }
1882 }
1883 return success();
1884}
1885
1886DiagnosedSilenceableFailure
1887PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1888 transform::TransformResults &transformResults,
1889 transform::TransformState &state) {
1890 SmallVector<Operation *> results;
1891 for (Operation *op : state.getPayloadOps(getTarget())) {
1892 auto linalgOp = dyn_cast<LinalgOp>(op);
1893 if (!linalgOp)
1894 continue;
1895 // linalgOp will be replaced and the insertion point may be invalidated if
1896 // we set it before -> set it after.
1897 rewriter.setInsertionPointAfter(linalgOp);
1898 // Failing to pack greedily is perfectly fine.
1899 // In the future we will want to order packings according to some metric.
1900 FailureOr<PackResult> packResult = packMatmulGreedily(
1901 /*rewriter=*/rewriter,
1902 /*linalgOp=*/linalgOp,
1903 /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1904 /*mnkPaddedSizesNextMultipleOf=*/
1905 getMatmulPaddedSizesNextMultipleOf(),
1906 /*mnkOrder=*/getMatmulInnerDimsOrder());
1907 if (succeeded(packResult)) {
1908 results.push_back(packResult->packedLinalgOp);
1909 continue;
1910 }
1911 results.push_back(linalgOp);
1912 }
1913 transformResults.set(cast<OpResult>(getPackedOp()), results);
1914 return DiagnosedSilenceableFailure::success();
1915}
1916
1917SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1918 Builder b(getContext());
1919 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1920 b);
1921}
1922
1923void transform::PackGreedilyOp::getEffects(
1924 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1925 transform::consumesHandle(getTargetMutable(), effects);
1926 transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1927 transform::producesHandle(getOperation()->getOpResults(), effects);
1928 transform::modifiesPayload(effects);
1929}
1930
1931//===---------------------------------------------------------------------===//
1932// PackTransposeOp
1933//===---------------------------------------------------------------------===//
1934
1935LogicalResult transform::PackTransposeOp::verify() {
1936 if (!isPermutationVector(getInnerPerm())) {
1937 return emitOpError() << getInnerPermAttrName()
1938 << " is not a valid permutation";
1939 }
1940 if (!isPermutationVector(getOuterPerm())) {
1941 return emitOpError() << getOuterPermAttrName()
1942 << " is not a valid permutation";
1943 }
1944 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1945 return emitOpError() << " at least one of " << getInnerPermAttrName()
1946 << " or " << getOuterPermAttrName()
1947 << " must be specified";
1948 }
1949 return success();
1950}
1951
1952namespace {
1953enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1954} // namespace
1955
1956/// Return true if `permutation` is a valid permutation of the
1957/// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1958/// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` `op`.
1959/// This is the case when the `permutation` rank matches the rank expected by
1960/// `op` and `permutation` is itself a permutation vector.
1961/// Return true if either `op` or `permutation` are empty to allow a simpler
1962/// polymorphic implementation.
1963template <typename RelayoutOpTy>
1964static bool isValidPackingPermutation(
1965 RelayoutOpTy op, ArrayRef<int64_t> permutation,
1966 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1967 static_assert(
1968 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1969 "applies to only pack or unpack operations");
1970 if (!op || permutation.empty())
1971 return true;
1972 size_t innerRank = op.getInnerDimsPos().size();
1973 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1974 return permutation.size() == innerRank && isPermutationVector(permutation);
1975 // op.getOuterDimsPerm() may be empty, in which case it is identity.
1976 // Don't rely on it.
1977 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1978 return permutation.size() == op.getSourceRank() &&
1979 isPermutationVector(permutation);
1980 }
1981 return permutation.size() == op.getDestRank() &&
1982 isPermutationVector(permutation);
1983}
1984
1985DiagnosedSilenceableFailure
1986transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1987 transform::TransformResults &transformResults,
1988 transform::TransformState &state) {
1989 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1990 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1991 // Step 1. If nothing to pack, propagate success.
1992 if (std::empty(packOrUnpackOps)) {
1993 transformResults.set(cast<OpResult>(getPackedOp()), {});
1994 transformResults.set(cast<OpResult>(getPackOp()), {});
1995 transformResults.set(cast<OpResult>(getUnPackOp()), {});
1996 return DiagnosedSilenceableFailure::success();
1997 }
1998
1999 // Step 2. Bunch of runtime sanity checks and error messages.
2000 // Step 2.1. Fail on multi-op handles.
2001 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2002 !llvm::hasSingleElement(linalgOps)) {
2003 return emitSilenceableError()
2004 << "requires target to map to exactly 1 "
2005 "packing op and 1 packed op ("
2006 << "got " << llvm::range_size(packOrUnpackOps) << " and "
2007 << llvm::range_size(linalgOps) << ")";
2008 }
2009
2010 // Step 2.2. Fail on wrong type.
2011 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2012 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2013 if ((!packOp && !unPackOp)) {
2014 return emitSilenceableError() << "requires target to map to a "
2015 "linalg.pack or linalg.unpack";
2016 }
2017 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2018 if (!linalgOpTarget)
2019 return emitSilenceableError() << "requires a LinalgOp target";
2020
2021 // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
2022 LinalgOp linalgOp;
2023 if (packOp && packOp.getResult().hasOneUse())
2024 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2025 else if (unPackOp)
2026 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2027 if (linalgOp != linalgOpTarget) {
2028 auto errorMsg =
2029 packOp ? StringLiteral{"not a single use by the LinalgOp target"}
2030 : StringLiteral{"not produced by the LinalgOp target"};
2031 return emitSilenceableError() << errorMsg;
2032 }
2033
2034 // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
2035 // PackOp.
2036 if (unPackOp) {
2037 assert(!packOp && "packOp must be null on entry when unPackOp is not null");
2038 OpOperand *packUse = linalgOp.getDpsInitOperand(
2039 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2040 packOp = packUse->get().getDefiningOp<linalg::PackOp>();
2041 if (!packOp || !packOp.getResult().hasOneUse())
2042 return emitSilenceableError() << "could not find matching pack op";
2043 }
2044
2045 // Step 2.5. Fail if any permutation does not validate.
2046 for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2047 ArrayRef<int64_t> perm =
2048 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2049 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2050 ? StringLiteral{"invalid outer_perm"}
2051 : StringLiteral{"invalid inner_perm"};
2052 if (!isValidPackingPermutation(packOp, perm, permType) ||
2053 !isValidPackingPermutation(unPackOp, perm, permType)) {
2054 Operation *packOrUnpackOp =
2055 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2056 return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
2057 }
2058 }
2059
2060 // From here on, packOp and linalgOp are always present, unPackOp may or may
2061 // not be present.
2062 assert(packOp && linalgOp && "unexpected null op");
2063
2064 // Step 3. Actually transpose the ops.
2065 FailureOr<PackTransposeResult> res = packTranspose(
2066 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2067 // Preconditions have been checked, it is an error to fail here.
2068 assert(succeeded(res) && "unexpected packTranspose failure");
2069
2070 // Step 4. Return results.
2071 transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2072 transformResults.set(cast<OpResult>(getPackedOp()),
2073 {res->transposedLinalgOp});
2074 if (unPackOp) {
2075 transformResults.set(cast<OpResult>(getUnPackOp()),
2076 {res->transposedUnPackOp});
2077 } else {
2078 transformResults.set(cast<OpResult>(getUnPackOp()), {});
2079 }
2080
2081 return DiagnosedSilenceableFailure::success();
2082}
2083
2084//===---------------------------------------------------------------------===//
2085// PadOp
2086//===---------------------------------------------------------------------===//
2087
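// Note: PadOp pads the operands of each targeted LinalgOp along
// padding_dimensions (optionally up to a multiple of pad_to_multiple_of),
// using padding_values for the new elements, and records three result handles:
// the padded ops, the tensor.pad ops that were created, and the copy-back ops
// (when copy_back_op is not kCopyOpNone) that write results back into the
// original destinations.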
2088void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2089 ArrayRef<int64_t> paddingDimensions,
2090 ArrayRef<int64_t> padToMultipleOf,
2091 ArrayRef<int64_t> nofoldFlags,
2092 ArrayRef<Attribute> transposePaddings,
2093 StringRef copyBackOp,
2094 bool usePrescribedTensorShapes) {
2095 auto resultType = transform::AnyOpType::get(b.getContext());
2096 return build(/*odsBuilder=*/b,
2097 /*result=*/result,
2098 /*types=*/TypeRange{resultType, resultType},
2099 /*target=*/target,
2100 /*padding_values=*/ArrayAttr(), // let inference handle this
2101 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2102 /*pad_to_multiple_of=*/ValueRange{},
2103 /*padToMultipleOf=*/
2104 (padToMultipleOf.empty()
2105 ? DenseI64ArrayAttr()
2106 : b.getDenseI64ArrayAttr(padToMultipleOf)),
2107 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2108 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2109 /*copy_back_op=*/b.getStringAttr(copyBackOp),
2110 /*use_prescribed_tensor_shapes=*/
2111 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
2112}
2113
2114void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
2115 ArrayRef<int64_t> paddingDimensions,
2116 ArrayRef<OpFoldResult> mixedPadToMultipleOf,
2117 ArrayRef<int64_t> nofoldFlags,
2118 ArrayRef<Attribute> transposePaddings,
2119 StringRef copyBackOp,
2120 bool usePrescribedTensorShapes) {
2121 auto resultType = transform::AnyOpType::get(b.getContext());
2122 SmallVector<int64_t> staticPadToMultipleOf;
2123 SmallVector<Value> dynamicPadToMultipleOf;
2124 dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
2125 staticPadToMultipleOf);
2126 return build(/*odsBuilder=*/b,
2127 /*result=*/result,
2128 /*types=*/TypeRange{resultType, resultType},
2129 /*target=*/target,
2130 /*padding_values=*/ArrayAttr(), // let inference handle this
2131 /*padding_dimensions=*/b.getI64ArrayAttr(paddingDimensions),
2132 /*pad_to_multiple_of=*/dynamicPadToMultipleOf,
2133 /*padToMultipleOf=*/staticPadToMultipleOf,
2134 /*nofold_flags=*/b.getI64ArrayAttr(nofoldFlags),
2135 /*transpose_paddings=*/b.getArrayAttr(transposePaddings),
2136 /*copy_back_op=*/copyBackOp,
2137 /*use_prescribed_tensor_shapes=*/usePrescribedTensorShapes);
2138}
2139
2140void PadOp::getEffects(
2141 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2142 consumesHandle(getTargetMutable(), effects);
2143 onlyReadsHandle(getPadToMultipleOfMutable(), effects);
2144 producesHandle(getOperation()->getOpResults(), effects);
2145 modifiesPayload(effects);
2146}
2147
2148SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2149 Builder b(getContext());
2150 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
2151}
2152
2153DiagnosedSilenceableFailure
2154transform::PadOp::apply(transform::TransformRewriter &rewriter,
2155 transform::TransformResults &results,
2156 transform::TransformState &state) {
2157 auto transformOp = cast<TransformOpInterface>(getOperation());
2158 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
2159
2160 for (Operation *target : state.getPayloadOps(getTarget())) {
2161 auto linalgTarget = dyn_cast<LinalgOp>(target);
2162 if (!linalgTarget) {
2163 auto diag = emitSilenceableError() << "expected LinalgOp target";
2164 diag.attachNote(target->getLoc()) << "target op";
2165 return diag;
2166 }
2167
2168 // Convert the integer packing flags to booleans.
2169 SmallVector<bool> nofoldFlags;
2170 for (int64_t packPadding :
2171 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
2172 nofoldFlags.push_back(static_cast<bool>(packPadding));
2173
2174 // Convert the padding values to attributes.
2175 SmallVector<Attribute> paddingValues;
2176 for (auto const &[untypedAttr, elementOrTensorType] :
2177 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2178
2179 if (isa<ub::PoisonAttr>(untypedAttr)) {
2180 paddingValues.push_back(untypedAttr);
2181 continue;
2182 }
2183 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2184 if (!attr) {
2185 emitOpError("expects padding values to be typed attributes or poison");
2186 return DiagnosedSilenceableFailure::definiteFailure();
2187 }
2188 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2189 // Try to parse string attributes to obtain an attribute of element type.
2190 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2191 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2192 stringAttr, getContext(), elementType,
2193 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2194 if (!parsedAttr || parsedAttr.getType() != elementType) {
2195 auto diag = this->emitOpError("expects a padding that parses to ")
2196 << elementType << ", got " << untypedAttr;
2197 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2198 return DiagnosedSilenceableFailure::definiteFailure();
2199 }
2200 paddingValues.push_back(parsedAttr);
2201 continue;
2202 }
2203 // Otherwise, add the attribute directly.
2204 if (attr.getType() != elementType) {
2205 auto diag = this->emitOpError("expects a padding value of type ")
2206 << elementType << ", got " << attr;
2207 diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2208 return DiagnosedSilenceableFailure::definiteFailure();
2209 }
2210 paddingValues.push_back(attr);
2211 }
2212
2213 // Extract the transpose vectors.
2214 SmallVector<SmallVector<int64_t>> transposePaddings;
2215 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2216 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2217 cast<ArrayAttr>(transposeVector)));
2218
2219 LinalgOp paddedOp;
2220 LinalgPaddingOptions options;
2221 options.paddingDimensions =
2222 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2223
2224 SmallVector<int64_t> padToMultipleOf;
2225 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
2226 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2227 if (!status.succeeded())
2228 return status;
2229 if (padToMultipleOf.empty())
2230 padToMultipleOf =
2231 SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2232
2233 options.padToMultipleOf = padToMultipleOf;
2234 options.paddingValues = paddingValues;
2235 options.nofoldFlags = nofoldFlags;
2236 if (getCopyBackOp() ==
2237 bufferization::MaterializeInDestinationOp::getOperationName()) {
2238 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2239 BufferizationMaterializeInDestination;
2240 } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2241 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2242 } else if (getCopyBackOp() == kCopyOpNone) {
2243 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2244 } else {
2245 llvm_unreachable("unsupported copy_back op");
2246 }
2247 // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2248 bool irChanged = false;
2249 if (getUsePrescribedTensorShapes() &&
2250 linalgTarget.hasPureTensorSemantics()) {
2251 OpBuilder::InsertionGuard g(rewriter);
2252 rewriter.setInsertionPoint(linalgTarget);
2253 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2254 for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2255 if (ShapedType::isStatic(dim))
2256 continue;
2257 options.setSizeToPadTo(operand.getOperandNumber(), i,
2258 tensor::getMixedSize(rewriter,
2259 operand.get().getLoc(),
2260 operand.get(), i));
2261 irChanged = true;
2262 }
2263 }
2264 }
2265
2266 SmallVector<Value> replacements;
2267 SmallVector<tensor::PadOp> newPadOps;
2268 if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2269 replacements, newPadOps))) {
2270 if (irChanged) {
2271 auto diag = emitDefiniteFailure() << "failed to pad op";
2272 diag.attachNote(target->getLoc()) << "target op";
2273 return diag;
2274 }
2275 auto diag = emitSilenceableError() << "failed to pad op";
2276 diag.attachNote(target->getLoc()) << "target op";
2277 return diag;
2278 }
2279
2280 // We need to perform our own replacement here because this API is still
2281 // used in patterns that "pad and hoist", for which the replacement values
2282 // need to be different.
2283 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2284 // that we have more composable abstractions.
2285 rewriter.replaceOp(linalgTarget, replacements);
2286 paddedOps.push_back(paddedOp);
2287 padOps.append(newPadOps.begin(), newPadOps.end());
2288 if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2289 for (Value v : replacements) {
2290 Operation *copyBackOp = v.getDefiningOp();
2291 if (!llvm::is_contained(copyBackOps, copyBackOp))
2292 copyBackOps.push_back(copyBackOp);
2293 }
2294 }
2295 }
2296
2297 results.set(cast<OpResult>(getPadded()), paddedOps);
2298 results.set(cast<OpResult>(getPad()), padOps);
2299 results.set(cast<OpResult>(getCopy()), copyBackOps);
2300 return DiagnosedSilenceableFailure::success();
2301}
2302
2303LogicalResult transform::PadOp::verify() {
2304 SmallVector<int64_t> nofoldFlags =
2305 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2306 if (any_of(nofoldFlags, [](int64_t packPadding) {
2307 return packPadding != 0 && packPadding != 1;
2308 })) {
2309 return emitOpError()
2310 << "expects nofold_flags to contain booleans (0/1), found "
2311 << getNofoldFlags();
2312 }
2313
2314 SmallVector<int64_t> paddingDimensions =
2315 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2316 if (any_of(paddingDimensions,
2317 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2318 return emitOpError() << "expects padding_dimensions to contain positive "
2319 "integers, found "
2320 << getPaddingDimensions();
2321 }
2322 if (!getMixedPadToMultipleOf().empty()) {
2323 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2324 return emitOpError() << "expects as many multiples as padding_dimensions";
2325 }
2326 }
2327 ArrayAttr transposes = getTransposePaddings();
2328 for (Attribute attr : transposes) {
2329 SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2330 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2331 if (!std::is_permutation(sequence.begin(), sequence.end(),
2332 transpose.begin(), transpose.end())) {
2333 return emitOpError()
2334 << "expects transpose_paddings to be a permutation, found "
2335 << attr;
2336 }
2337 }
2338 if (getCopyBackOp() !=
2339 bufferization::MaterializeInDestinationOp::getOperationName() &&
2340 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2341 getCopyBackOp() != kCopyOpNone)
2342 return emitOpError() << "invalid copy_back_op";
2343 return success();
2344}
2345
2346//===---------------------------------------------------------------------===//
2347// PadTilingInterfaceOp
2348//===---------------------------------------------------------------------===//
2349
2350void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2351 OperationState &result,
2352 Value target,
2353 ArrayRef<int64_t> paddingSizes,
2354 bool padToMultipleOf) {
2355 auto resultType = transform::AnyOpType::get(b.getContext());
2356 return build(/*odsBuilder=*/b,
2357 /*result=*/result,
2358 /*types=*/TypeRange{resultType, resultType},
2359 /*target=*/target,
2360 /*padding_values=*/ArrayAttr(), // let inference handle this
2361 /*padding_sizes=*/ValueRange{},
2362 /*paddingSizes=*/
2363 (paddingSizes.empty() ? DenseI64ArrayAttr()
2364 : b.getDenseI64ArrayAttr(paddingSizes)),
2365 /*pad_to_multiple_of=*/
2366 padToMultipleOf ? b.getUnitAttr() : nullptr);
2367}
2368
2369void transform::PadTilingInterfaceOp::build(
2370 OpBuilder &b, OperationState &result, Value target,
2371 ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2372 auto resultType = transform::AnyOpType::get(b.getContext());
2373 SmallVector<int64_t> staticPaddingSizes;
2374 SmallVector<Value> dynamicPaddingSizes;
2375 dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2376 staticPaddingSizes);
2377 return build(/*odsBuilder=*/b,
2378 /*result=*/result,
2379 /*types=*/TypeRange{resultType, resultType},
2380 /*target=*/target,
2381 /*padding_values=*/ArrayAttr(), // let inference handle this
2382 /*padding_sizes=*/dynamicPaddingSizes,
2383 /*paddingSizes=*/staticPaddingSizes,
2384 /*usePrescribedTensorShapes=*/padToMultipleOf);
2385}
2386
2387void transform::PadTilingInterfaceOp::getEffects(
2388 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2389 consumesHandle(getTargetMutable(), effects);
2390 onlyReadsHandle(getPaddingSizesMutable(), effects);
2391 producesHandle(getOperation()->getOpResults(), effects);
2392 modifiesPayload(effects);
2393}
2394
2395SmallVector<OpFoldResult>
2396transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2397 Builder b(getContext());
2398 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2399}
2400
2401DiagnosedSilenceableFailure
2402transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2403 transform::TransformResults &results,
2404 transform::TransformState &state) {
2405 SmallVector<Operation *> paddedOps, padOps;
2406
2407 for (Operation *target : state.getPayloadOps(getTarget())) {
2408 auto targetOp = dyn_cast<TilingInterface>(target);
2409 if (!targetOp) {
2410 auto diag = emitSilenceableError() << "expected TilingInterface target";
2411 diag.attachNote(target->getLoc()) << "target op";
2412 return diag;
2413 }
2414
2415 // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2416 // loopsToOperand map / C++ APIs to compute the effect of padding on
2417 // operands.
2418 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2419 auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2420 "supported atm";
2421 diag.attachNote(target->getLoc()) << "target op";
2422 return diag;
2423 }
2424
2425 // Convert the padding values to attributes.
2426 SmallVector<Attribute> paddingValues;
2427 for (auto const &[untypedAttr, elementOrTensorType] :
2428 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2429 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2430 Type elementType = getElementTypeOrSelf(elementOrTensorType);
2431
2432 if (isa<ub::PoisonAttr>(untypedAttr)) {
2433 paddingValues.push_back(untypedAttr);
2434 continue;
2435 }
2436 if (!attr) {
2437 emitOpError("expects padding values to be typed attributes or poison");
2438 return DiagnosedSilenceableFailure::definiteFailure();
2439 }
2440 // Try to parse string attributes to obtain an attribute of element type.
2441 if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2442 auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2443 stringAttr, getContext(), elementType,
2444 /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2445 if (!parsedAttr || parsedAttr.getType() != elementType) {
2446 auto diag = this->emitOpError("expects a padding that parses to ")
2447 << elementType << ", got " << attr;
2448 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2449 return DiagnosedSilenceableFailure::definiteFailure();
2450 }
2451 paddingValues.push_back(parsedAttr);
2452 continue;
2453 }
2454 // Otherwise, add the attribute directly.
2455 if (attr.getType() != elementType) {
2456 auto diag = this->emitOpError("expects a padding value of type ")
2457 << elementType << ", got " << attr;
2458 diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2459 return DiagnosedSilenceableFailure::definiteFailure();
2460 }
2461 paddingValues.push_back(attr);
2462 }
2463
2464 // Set options.
2465 PadTilingInterfaceOptions options;
2466 options.setPaddingValues(paddingValues)
2467 .setPaddingSizes(getMixedPaddingSizes())
2468 .setPadToMultipleOf(getPadToMultipleOf());
2469
2470 OpBuilder::InsertionGuard g(rewriter);
2471 rewriter.setInsertionPointAfter(targetOp);
2472 auto maybePadOps = rewriteAsPaddedOp(
2473 rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
2474 if (failed(maybePadOps)) {
2475 auto diag = emitSilenceableError() << "failed to pad op";
2476 diag.attachNote(target->getLoc()) << "target op";
2477 return diag;
2478 }
2479 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2480
2481 // Set transform results.
2482 paddedOps.push_back(paddedOp);
2483 padOps.append(paddedOperands.begin(), paddedOperands.end());
2484 rewriter.replaceOp(targetOp.getOperation(), slicedResults);
2485 }
2486
2487 results.set(cast<OpResult>(getPadded()), paddedOps);
2488 results.set(cast<OpResult>(getPad()), padOps);
2489 return DiagnosedSilenceableFailure::success();
2490}
2491
2492LogicalResult transform::PadTilingInterfaceOp::verify() { return success(); }
2493
2494//===---------------------------------------------------------------------===//
2495// HoistPadOp
2496//===---------------------------------------------------------------------===//
2497
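// Note: HoistPadBuildPackingLoopNestOp only builds the packing loop nest that
// hoists the targeted tensor.pad out of the given scf.for loop; the result
// handle points to the outermost cloned packing loop, or directly to the
// hoisted pad when no loop IVs were cloned. HoistPadOp further below performs
// the full hoisting and replacement.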
2498DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2499 transform::TransformRewriter &rewriter,
2500 transform::TransformResults &transformResults,
2501 transform::TransformState &state) {
2502 auto targetOps = state.getPayloadOps(getTarget());
2503 auto loopOps = state.getPayloadOps(getLoop());
2504 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2505 return emitDefiniteFailure()
2506 << "requires exactly one target and one loop handle (got "
2507 << llvm::range_size(targetOps) << " and "
2508 << llvm::range_size(loopOps) << ")";
2509 }
2510
2511 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2512 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2513 if (!padOp || !loopOp)
2514 return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2515
2516 FailureOr<linalg::detail::PackingResult> result =
2517 linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2518 getTranspose());
2519 if (failed(result))
2520 return emitDefiniteFailure() << "could not build packing loop nest";
2521
2522 if (result->clonedLoopIvs.empty()) {
2523 transformResults.set(cast<OpResult>(getPackingLoop()),
2524 {result->hoistedPadOp.getOperation()});
2525 return DiagnosedSilenceableFailure::success();
2526 }
2527 auto outerPackedLoop =
2528 scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2529 transformResults.set(cast<OpResult>(getPackingLoop()),
2530 {outerPackedLoop.getOperation()});
2531 return DiagnosedSilenceableFailure::success();
2532}
2533
2534LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2535 ArrayRef<int64_t> transpose = getTranspose();
2536 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2537 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2538 transpose.end())) {
2539 return emitOpError() << "expects transpose to be a permutation, found "
2540 << getTranspose();
2541 }
2542 return success();
2543}
2544
2545void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2546 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2547 transform::onlyReadsHandle(getTargetMutable(), effects);
2548 transform::onlyReadsHandle(getLoopMutable(), effects);
2549 transform::producesHandle(getOperation()->getOpResults(), effects);
2550 transform::modifiesPayload(effects);
2551}
2552
2553DiagnosedSilenceableFailure
2554transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2555 tensor::PadOp target,
2556 transform::ApplyToEachResultList &results,
2557 transform::TransformState &state) {
2558 tensor::PadOp hoistedPadOp;
2559 SmallVector<TransposeOp> transposeOps;
2560 FailureOr<Value> result =
2561 hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2562 hoistedPadOp, transposeOps);
2563 if (succeeded(result)) {
2564 // We need to perform our own replacement here because this API is still
2565 // used in patterns that "pad and hoist", for which the replacement values
2566 // need to be different.
2567 // TODO: clean this up and stop "pad and hoist" behavior more globally now
2568 // that we have more composable abstractions.
2569 rewriter.replaceOp(target, *result);
2570 results.push_back(hoistedPadOp);
2571 return DiagnosedSilenceableFailure::success();
2572 }
2573 return emitDefaultSilenceableFailure(target);
2574}
2575
2576LogicalResult transform::HoistPadOp::verify() {
2577 ArrayRef<int64_t> transpose = getTranspose();
2578 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2579 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2580 transpose.end())) {
2581 return emitOpError() << "expects transpose to be a permutation, found "
2582 << getTranspose();
2583 }
2584 return success();
2585}
2586
2587//===----------------------------------------------------------------------===//
2588// PromoteOp
2589//===----------------------------------------------------------------------===//
2590
2591DiagnosedSilenceableFailure
2592transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2593 LinalgOp target,
2594 transform::ApplyToEachResultList &results,
2595 transform::TransformState &state) {
2596 LinalgPromotionOptions promotionOptions;
2597 if (!getOperandsToPromote().empty())
2598 promotionOptions = promotionOptions.setOperandsToPromote(
2599 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2600 if (getUseFullTilesByDefault())
2601 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2602 getUseFullTilesByDefault());
2603 if (getUseOriginalSubviewSize())
2604 promotionOptions =
2605 promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize());
2606 if (getUseAlloca())
2607 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2608 if (!getUseFullTileBuffers().empty())
2609 promotionOptions = promotionOptions.setUseFullTileBuffers(
2610 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2611 if (getAlignment().has_value())
2612 promotionOptions = promotionOptions.setAlignment(*getAlignment());
2613 if (getMemorySpace().has_value())
2614 promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2615
2616 if (getMapping().has_value()) {
2617 // The mapping should only contain a single element.
2618 auto mapping = *getMapping();
2619 if (mapping.size() > 1)
2620 return emitDefaultDefiniteFailure(target);
2621
2622 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2623
2624 if (addressSpace.getAddressSpace() ==
2625 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2626 promotionOptions =
2627 promotionOptions
2631 .setUseFullTileBuffers({false, false});
2632 } else if (addressSpace.getAddressSpace() ==
2633 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2634 promotionOptions =
2635 promotionOptions
2639 .setUseFullTileBuffers({false, false});
2640 } else {
2641 return emitDefaultDefiniteFailure(target);
2642 }
2643 }
2644
2645 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2646 return emitDefaultDefiniteFailure(target);
2647
2648 rewriter.setInsertionPoint(target);
2649 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2650 if (failed(res))
2651 return emitDefaultDefiniteFailure(target);
2652 results.push_back(target);
2653 return DiagnosedSilenceableFailure::success();
2654}
2655
2656//===----------------------------------------------------------------------===//
2657// ReplaceOp
2658//===----------------------------------------------------------------------===//
2659
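// Note: ReplaceOp substitutes each payload op mapped to the target handle with
// a clone of the single op held in its body region. Targets must have no
// operands and no non-isolated regions (checked below), since the replacement
// is cloned verbatim without remapping any values.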
2660DiagnosedSilenceableFailure
2661transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2662 TransformResults &transformResults,
2663 TransformState &state) {
2664 auto payload = state.getPayloadOps(getTarget());
2665
2666 // Check for invalid targets.
2667 for (Operation *target : payload) {
2668 if (target->getNumOperands() > 0)
2669 return emitDefiniteFailure() << "expected target without operands";
2670 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2671 target->getNumRegions() > 0)
2672 return emitDefiniteFailure()
2673 << "expected target that is isolated from above";
2674 }
2675
2676 // Clone and replace.
2677 Operation *pattern = &getBodyRegion().front().front();
2678 SmallVector<Operation *> replacements;
2679 for (Operation *target : payload) {
2680 if (getOperation()->isAncestor(target))
2681 continue;
2682 rewriter.setInsertionPoint(target);
2683 Operation *replacement = rewriter.clone(*pattern);
2684 rewriter.replaceOp(target, replacement->getResults());
2685 replacements.push_back(replacement);
2686 }
2687 transformResults.set(cast<OpResult>(getReplacement()), replacements);
2688 return DiagnosedSilenceableFailure::success();
2689}
2690
2691void transform::ReplaceOp::getEffects(
2692 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2693 consumesHandle(getTargetMutable(), effects);
2694 producesHandle(getOperation()->getOpResults(), effects);
2695 modifiesPayload(effects);
2696}
2697
2698LogicalResult transform::ReplaceOp::verify() {
2699 if (!getBodyRegion().hasOneBlock())
2700 return emitOpError() << "expected one block";
2701 if (std::distance(getBodyRegion().front().begin(),
2702 getBodyRegion().front().end()) != 1)
2703 return emitOpError() << "expected one operation in block";
2704 Operation *replacement = &getBodyRegion().front().front();
2705 if (replacement->getNumOperands() > 0)
2706 return replacement->emitOpError()
2707 << "expected replacement without operands";
2708 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2709 replacement->getNumRegions() > 0)
2710 return replacement->emitOpError()
2711 << "expect op that is isolated from above";
2712 return success();
2713}
2714
2715//===----------------------------------------------------------------------===//
2716// ScalarizeOp
2717//===----------------------------------------------------------------------===//
2718
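// Note: scalarization is implemented as tiling: every loop whose corresponding
// shape size is dynamic gets tile size 1, while statically sized loops are left
// untiled (tile size 0), so the tiled ops no longer have dynamic dimensions
// along the tiled loops.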
2719DiagnosedSilenceableFailure
2720transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2721 LinalgOp target,
2722 transform::ApplyToEachResultList &results,
2723 transform::TransformState &state) {
2724 scf::SCFTilingOptions tilingOptions;
2725 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2726 SmallVector<OpFoldResult> tileSizes;
2727 Location loc = target.getLoc();
2728 SmallVector<OpFoldResult> allShapeSizes =
2729 target.createFlatListOfOperandDims(b, loc);
2730 AffineMap map = target.getShapesToLoopsMap();
2731 if (!map)
2732 return tileSizes;
2733 SmallVector<OpFoldResult> shapeSizes =
2734 affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
2735 allShapeSizes);
2736 // If the shape size is dynamic, tile by 1.
2737 // Otherwise, do not tile (i.e. tile size 0).
2738 for (OpFoldResult shapeSize : shapeSizes) {
2739 tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2740 : b.getIndexAttr(1));
2741 }
2742 return tileSizes;
2743 });
2744 rewriter.setInsertionPoint(target);
2745 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2746 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2747 if (failed(maybeTilingResult))
2748 return emitDefaultDefiniteFailure(target);
2749
2750 if (target->getNumResults())
2751 rewriter.replaceOp(target, maybeTilingResult->replacements);
2752 else
2753 rewriter.eraseOp(target);
2754
2755 results.reserve(maybeTilingResult->tiledOps.size());
2756 for (Operation *tiled : maybeTilingResult->tiledOps)
2757 results.push_back(tiled);
2758 return DiagnosedSilenceableFailure::success();
2759}
2760
2761//===----------------------------------------------------------------------===//
2762// ConvertToLoopsOp
2763//===----------------------------------------------------------------------===//
2764
2765DiagnosedSilenceableFailure
2766transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2767 transform::TransformResults &results,
2768 transform::TransformState &state) {
2769 SmallVector<Operation *> loops;
2770 for (Operation *target : state.getPayloadOps(getTarget())) {
2771 auto tilingOp = dyn_cast<TilingInterface>(*target);
2772 if (!tilingOp) {
2773 DiagnosedSilenceableFailure diag =
2774 emitSilenceableError()
2775 << "expected the payload to implement TilingInterface";
2776 diag.attachNote(target->getLoc()) << "payload op";
2777 return diag;
2778 }
2779 rewriter.setInsertionPoint(target);
2780 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2781 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2782 if (failed(generatedLoops))
2783 return emitDefaultDefiniteFailure(target);
2784 for (scf::ForOp &loop : *generatedLoops) {
2785 loops.push_back(loop.getOperation());
2786 }
2787 rewriter.eraseOp(target);
2788 }
2789 results.set(cast<OpResult>(getResult()), loops);
2790 return DiagnosedSilenceableFailure::success();
2791}
2792
2793//===----------------------------------------------------------------------===//
2794// RewriteInDestinationPassingStyleOp
2795//===----------------------------------------------------------------------===//
2796
2797DiagnosedSilenceableFailure
2798transform::RewriteInDestinationPassingStyleOp::applyToOne(
2799 transform::TransformRewriter &rewriter, Operation *target,
2800 transform::ApplyToEachResultList &results,
2801 transform::TransformState &state) {
2802 rewriter.setInsertionPoint(target);
2803 FailureOr<Operation *> maybeResult =
2804 llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2805 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2806 [&rewriter](auto op) {
2807 return rewriteInDestinationPassingStyle(rewriter, op);
2808 });
2809 if (failed(maybeResult))
2810 return emitDefaultSilenceableFailure(target);
2811 results.push_back(*maybeResult);
2812 return DiagnosedSilenceableFailure::success();
2813}
2814
2815//===----------------------------------------------------------------------===//
2816// SplitOp
2817//===----------------------------------------------------------------------===//
2818
2819DiagnosedSilenceableFailure
2820SplitOp::apply(transform::TransformRewriter &rewriter,
2821 TransformResults &results, TransformState &state) {
2822 // Collect the dynamic split points if provided.
2823 SmallVector<Operation *> payload =
2824 llvm::to_vector(state.getPayloadOps(getTarget()));
2825
2826 bool isMultiwaySplit = getMultiway();
2827
2828 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2829 return mlir::emitSilenceableFailure(getLoc())
2830 << "requires exactly one target when "
2831 "multiway split is enabled (got "
2832 << llvm::range_size(payload) << ")";
2833 }
2834
2835 SmallVector<OpFoldResult> chunkSizes;
2836
2837 if (!isMultiwaySplit)
2838 chunkSizes.reserve(payload.size());
2839
2840 if (getDynamicChunkSizes()) {
2841 auto diag = DiagnosedSilenceableFailure::success();
2842 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2843 chunkSizes = llvm::to_vector(llvm::map_range(
2844 state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2845 if (op->getNumResults() != 1 ||
2846 !op->getResult(0).getType().isIndex()) {
2847 diag = emitSilenceableError()
2848 << "expected dynamic split point handle to point to a "
2849 "single-result index-typed op";
2850 diag.attachNote(op->getLoc()) << "dynamic split point";
2851 }
2852 return OpFoldResult(op->getResult(0));
2853 }));
2854 } else {
2855 chunkSizes = llvm::to_vector(
2856 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2857 [](Attribute attr) { return OpFoldResult(attr); }));
2858 }
2859 if (diag.isSilenceableFailure())
2860 return diag;
2861
2862 // For multiway split, a single payload is expected to have multiple
2863 // split points.
2864 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2865 return emitDefiniteFailure()
2866 << "expected the dynamic split point handle to point to as "
2867 "many operations ("
2868 << chunkSizes.size() << ") as the target handle ("
2869 << payload.size() << ")";
2870 }
2871 } else {
2872 chunkSizes.resize(payload.size(),
2873 rewriter.getIndexAttr(getStaticChunkSizes()));
2874 }
2875
2876 auto checkStructuredOpAndDimensions =
2877 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2878 if (!linalgOp) {
2879 auto diag = emitSilenceableError() << "only applies to structured ops";
2880 diag.attachNote(loc) << "target op";
2881 return diag;
2882 }
2883
2884 if (getDimension() >= linalgOp.getNumLoops()) {
2885 auto diag = emitSilenceableError() << "dimension " << getDimension()
2886 << " does not exist in target op";
2887 diag.attachNote(loc) << "target op";
2888 return diag;
2889 }
2890 return DiagnosedSilenceableFailure::success();
2891 };
2892
2893 auto checkFailureInSplitting =
2894 [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2895 if (hasFailed) {
2896 auto diag = emitDefiniteFailure() << "internal failure in splitting";
2897 diag.attachNote(loc) << "target op";
2898 return diag;
2899 }
2900 return DiagnosedSilenceableFailure::success();
2901 };
2902
2903 SmallVector<Operation *> opList;
2904 if (isMultiwaySplit) {
2905
2906 // Split a single target operation at multiple points.
2907 TilingInterface head, tail;
2908 Operation *target = payload.front();
2909
2910 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2911
2912 // Check that the target is a valid LinalgOp with correct dimensions.
2913 DiagnosedSilenceableFailure diag =
2914 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2915 if (diag.isSilenceableFailure())
2916 return diag;
2917
2918 for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2919
2920 if (idx > 0)
2921 target = tail.getOperation();
2922
2923 if (!target)
2924 break;
2925
2926 linalgOp = cast<LinalgOp>(target);
2927 Location loc = target->getLoc();
2928
2929 rewriter.setInsertionPoint(linalgOp);
2930 std::tie(head, tail) = linalg::splitOp(
2931 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2932 getDimension(), chunkSize);
2933
2934 // Propagate errors.
2935 DiagnosedSilenceableFailure diag =
2936 checkFailureInSplitting(!head && !tail, loc);
2937 if (diag.isDefiniteFailure())
2938 return diag;
2939
2940 opList.push_back(head.getOperation());
2941 }
2942
2943 // Append any leftover parts to the end of the result list.
2944 if (tail)
2945 opList.push_back(tail.getOperation());
2946
2947 } else {
2948 // Split each target operation.
2949 SmallVector<Operation *> first, second;
2950 Operation *noSecondPart = nullptr;
2951 for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2952 Operation *target = std::get<0>(pair);
2953 Location loc = target->getLoc();
2954 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2955 DiagnosedSilenceableFailure diag =
2956 checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2957
2958 if (diag.isSilenceableFailure())
2959 return diag;
2960
2961 rewriter.setInsertionPoint(linalgOp);
2962 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2963 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2964 getDimension(), std::get<1>(pair));
2965
2966 // Propagate errors.
2967 DiagnosedSilenceableFailure diagSplit =
2968 checkFailureInSplitting(!first.back() && !second.back(), loc);
2969 if (diagSplit.isDefiniteFailure())
2970 return diagSplit;
2971
2972 // Do not add null second parts.
2973 if (!second.back()) {
2974 noSecondPart = target;
2975 second.pop_back();
2976 }
2977 }
2978
2979 if (second.size() != first.size() && !second.empty()) {
2980 auto diag = emitSilenceableError()
2981 << "splitting does not produce the second part for a subset "
2982 "of targets";
2983 diag.attachNote()
2984 << "expected splitting to produce the second part of all "
2985 "or none of the targets";
2986 diag.attachNote(noSecondPart->getLoc())
2987 << "first target with no second part";
2988 return diag;
2989 }
2990
2991 opList.append(first);
2992 if (!second.empty())
2993 opList.append(second);
2994 }
2995 results.set(cast<OpResult>(getSplitList()), opList);
2996 return DiagnosedSilenceableFailure::success();
2997}
2998
2999void SplitOp::getEffects(
3000 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3001 consumesHandle(getTargetMutable(), effects);
3002 if (getDynamicChunkSizes())
3003 onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
3004 producesHandle(getOperation()->getOpResults(), effects);
3005 modifiesPayload(effects);
3006}
3007
3008ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
3009 OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
3010 IntegerAttr staticChunkSizes;
3011 if (parser.parseOperand(target) || parser.parseKeyword("after"))
3012 return failure();
3013
3014 OptionalParseResult dynamicPointParseResult =
3015 parser.parseOptionalOperand(dynamicChunkSizes);
3016 if (!dynamicPointParseResult.has_value()) {
3017 int64_t staticChunkSizesValue;
3018 if (failed(parser.parseInteger(staticChunkSizesValue)))
3019 return failure();
3020
3021 staticChunkSizes =
3022 parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
3023 }
3024
3025 Type targetType;
3026 if (parser.parseOptionalAttrDict(result.attributes) ||
3027 parser.parseColonType(targetType) ||
3028 parser.resolveOperand(target, targetType, result.operands)) {
3029 return failure();
3030 }
3031 if (dynamicPointParseResult.has_value()) {
3032 Type chunkSizesType;
3033 if (failed(*dynamicPointParseResult) || parser.parseComma() ||
3034 parser.parseType(chunkSizesType) ||
3035 parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
3036 result.operands)) {
3037 return failure();
3038 }
3039
3040 staticChunkSizes =
3041 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
3042 }
3043
3044 result.addAttribute(
3045 SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
3046 staticChunkSizes);
3047 result.addTypes(targetType);
3048 return success();
3049}
3050
3051void SplitOp::print(OpAsmPrinter &printer) {
3052 printer << " " << getTarget() << " after ";
3053 int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
3054 if (staticChunkSize != ShapedType::kDynamic)
3055 printer << staticChunkSize;
3056 else
3057 printer << getDynamicChunkSizes();
3058 printer << " ";
3059 printer.printOptionalAttrDict(getOperation()->getAttrs(),
3060 {getStaticChunkSizesAttrName()});
3061 printer << " : " << getTarget().getType();
3062 if (staticChunkSize == ShapedType::kDynamic)
3063 printer << ", " << getDynamicChunkSizes().getType();
3064}
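// The custom parser/printer above yield a textual form along these lines
// (a sketch; handle names and the exact attribute spelling are illustrative):
//   transform.structured.split %target after 16 {dimension = 0}
//       : !transform.any_op
// or, with a dynamic split point carried by a handle or param:
//   transform.structured.split %target after %sz {dimension = 0}
//       : !transform.any_op, !transform.param<i64>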
3065
3066LogicalResult SplitOp::verify() {
3067 if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3068 (getDynamicChunkSizes() == nullptr)) {
3069 return emitOpError() << "expects either a dynamic or a static split "
3070 "point to be provided";
3071 }
3072 return success();
3073}
3074
3075//===----------------------------------------------------------------------===//
3076// SplitReductionOp
3077//===----------------------------------------------------------------------===//
3078
3079void transform::SplitReductionOp::build(
3080 OpBuilder &builder, OperationState &result, Value target,
3081 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
3082 bool useScalingAlgorithm, bool useAlloc) {
3083 MLIRContext *ctx = builder.getContext();
3084 result.addOperands(target);
3085 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
3086 builder.getI64IntegerAttr(splitFactor));
3087 result.addAttribute(
3088 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
3089 builder.getI64IntegerAttr(insertSplitDimension));
3090 if (innerParallel) {
3091 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
3092 builder.getUnitAttr());
3093 }
3094 if (useScalingAlgorithm) {
3095 result.addAttribute(
3096 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
3097 builder.getUnitAttr());
3098 }
3099 if (useAlloc) {
3100 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
3101 builder.getUnitAttr());
3102 }
3103 auto resultType = transform::AnyOpType::get(ctx);
3104 result.addTypes({resultType, resultType, resultType, resultType});
3105}
3106
3107DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3108 transform::TransformRewriter &rewriter, LinalgOp target,
3109 transform::ApplyToEachResultList &results,
3110 transform::TransformState &state) {
3111 ControlSplitReductionFn splitFn = [&](LinalgOp) {
3112 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3113 unsigned(getInsertSplitDimension()),
3114 bool(getInnerParallel())};
3115 };
3116 rewriter.setInsertionPoint(target);
3117 FailureOr<SplitReductionResult> splitResult =
3118 (getUseScalingAlgorithm())
3119 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
3120 : splitReduction(rewriter, target, splitFn, getUseAlloc());
3121 if (failed(splitResult))
3122 return emitDefaultDefiniteFailure(target);
3123
3124 results.push_back(splitResult->initOrAlloc);
3125 results.push_back(splitResult->fillOp);
3126 results.push_back(splitResult->splitLinalgOp);
3127 results.push_back(splitResult->resultCombiningLinalgOp);
3128 return DiagnosedSilenceableFailure::success();
3129}
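// Editorial sketch (not part of the original source): conceptually, splitting
// a reduction with split_factor = 4 turns, e.g., a 128-element sum into a
// partial reduction over a 4x32 reshaped view followed by a final 4-element
// combine. The four handles pushed above map, in order, to the init/alloc op,
// the fill of the partial accumulator, the split (partial) linalg op, and the
// combining linalg op; the exact ops depend on the payload and options.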
3130
3131//===----------------------------------------------------------------------===//
3132// TileReductionUsingForOp
3133//===----------------------------------------------------------------------===//
3134
3135void transform::TileReductionUsingForOp::build(
3136 OpBuilder &builder, OperationState &result, Value target,
3137 ArrayRef<int64_t> staticTileSizes) {
3138 // Call the default builder.
3139 // This is future-proof re mixed static-dynamic and setting up the proper
3140 // operands segment sizes attributes for multiple variadic operands.
3141 // In the absence of this, horrible bugs ensue.
3142 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3143 MLIRContext *ctx = builder.getContext();
3144 auto opTy = transform::AnyOpType::get(ctx);
3145 auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
3146 build(builder, result,
3147 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3148 /*target=*/target,
3149 /*reduction_dims=*/nullptr,
3150 /*tile_sizes=*/staticTileSizesAttr);
3151}
3152
3153DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3154 transform::TransformRewriter &rewriter, Operation *target,
3155 transform::ApplyToEachResultList &results,
3156 transform::TransformState &state) {
3157 rewriter.setInsertionPoint(target);
3158
3159 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3160 if (!partialReductionOp) {
3161 return emitSilenceableFailure(
3162 target->getLoc(),
3163 "Operation should implement PartialReductionOpInterface");
3164 }
3165
3166 SmallVector<unsigned> reductionDims =
3167 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3168 if (reductionDims.empty()) {
3169 for (auto [idx, iteratorType] :
3170 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3171 if (iteratorType == utils::IteratorType::reduction)
3172 reductionDims.push_back(idx);
3173 }
3174 }
3175
3176 scf::SCFTilingOptions options;
3177 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3178 options.setReductionTilingStrategy(
3179 ReductionTilingStrategy::PartialReductionOuterReduction);
3180 options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
3181 options.setReductionDims(reductionDims);
3182 FailureOr<scf::SCFTilingResult> result =
3183 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3184
3185 if (failed(result)) {
3186 return emitSilenceableFailure(getLoc(),
3187 "failed to tile using partial reduction");
3188 }
3189 rewriter.replaceOp(target, result->replacements);
3190 for (Value initValue : result->initialValues)
3191 results.push_back(initValue.getDefiningOp());
3192 for (auto parallelTiledOp : result->tiledOps)
3193 results.push_back(parallelTiledOp);
3194 for (auto mergeOp : result->mergeOps)
3195 results.push_back(mergeOp);
3196 results.push_back(result->loops.front());
3197 return DiagnosedSilenceableFailure::success();
3198}
3199
3200//===----------------------------------------------------------------------===//
3201// TileReductionUsingForallOp
3202//===----------------------------------------------------------------------===//
3203
3204void transform::TileReductionUsingForallOp::build(
3205 OpBuilder &builder, OperationState &result, Value target,
3206 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3207 ArrayAttr mapping) {
3208 // Call the default builder.
3209 // This is future-proof re mixed static-dynamic and setting up the proper
3210 // operands segment sizes attributes for multiple variadic operands.
3211 // In the absence of this, horrible bugs ensue.
3212 // TODO: support mixed static-dynamic (see TileUsingForallOp).
3213 MLIRContext *ctx = builder.getContext();
3214 auto opTy = transform::AnyOpType::get(ctx);
3215 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3216 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3217 build(builder, result,
3218 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3219 /*target=*/target,
3220 /*reduction_dims=*/{},
3221 /*num_threads=*/staticNumThreadsAttr,
3222 /*tile_sizes=*/staticTileSizesAttr,
3223 /*mapping=*/mapping);
3224}
3225
3226DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3227 transform::TransformRewriter &rewriter, Operation *target,
3228 transform::ApplyToEachResultList &results,
3229 transform::TransformState &state) {
3230 rewriter.setInsertionPoint(target);
3231
3232 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3233 if (!partialReductionOp) {
3234 return emitSilenceableFailure(
3235 target->getLoc(),
3236 "Operation should implement PartialReductionOpInterface");
3237 }
3238 SmallVector<OpFoldResult> numThreads =
3239 getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
3240 SmallVector<OpFoldResult> tileSizes =
3241 getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
3242
3243 scf::SCFTilingOptions options;
3244 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3245 options.setReductionTilingStrategy(
3246 ReductionTilingStrategy::PartialReductionOuterParallel);
3247 if (!getNumThreads().empty()) {
3248 options.setNumThreads(numThreads);
3249 } else {
3250 options.setTileSizes(tileSizes);
3251 }
3252 if (auto mapping = getMapping()) {
3253 options.setMapping(mapping.value().getValue());
3254 }
3255 SmallVector<unsigned> reductionDims =
3256 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3257 if (reductionDims.empty()) {
3258 for (auto [idx, iteratorType] :
3259 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3260 if (iteratorType == utils::IteratorType::reduction)
3261 reductionDims.push_back(idx);
3262 }
3263 }
3264 options.setReductionDims(reductionDims);
3265 FailureOr<scf::SCFTilingResult> result =
3266 scf::tileUsingSCF(rewriter, partialReductionOp, options);
3267
3268 if (failed(result)) {
3269 auto diag = emitSilenceableError() << "could not tile reduction";
3270 return diag;
3271 }
3272 rewriter.replaceOp(target, result->replacements);
3273
3274 for (Value initValue : result->initialValues)
3275 results.push_back(initValue.getDefiningOp());
3276 for (auto parallelTiledOp : result->tiledOps)
3277 results.push_back(parallelTiledOp);
3278 for (auto mergeOp : result->mergeOps)
3279 results.push_back(mergeOp);
3280 results.push_back(result->loops.front());
3281 return DiagnosedSilenceableFailure::success();
3282}
3283
3284//===----------------------------------------------------------------------===//
3285// ContinuousTileSizesOp
3286//===----------------------------------------------------------------------===//
3287
3288DiagnosedSilenceableFailure
3289transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3290 TransformResults &transformResults,
3291 TransformState &state) {
3292
3293 SmallVector<Operation *> targetOps =
3294 llvm::to_vector(state.getPayloadOps(getTarget()));
3295
3296 if (!llvm::hasSingleElement(targetOps)) {
3297 return mlir::emitSilenceableFailure(getLoc())
3298 << "requires exactly one target (got " << llvm::range_size(targetOps)
3299 << ")";
3300 }
3301
3302 Operation *target = *targetOps.begin();
3303 auto linalgOp = dyn_cast<LinalgOp>(target);
3304 auto tileableOp = dyn_cast<TilingInterface>(target);
3305
3306 if (!linalgOp)
3307 return emitDefiniteFailure() << "expected Linalg Op";
3308
3309 OpBuilder builder(linalgOp.getContext());
3310
3311 if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
3312 if (linalgOp.hasDynamicShape()) {
3313 auto diag = emitSilenceableError()
3314 << "cannot compute parametric tile sizes for dynamically "
3315 "shaped payload op";
3316 diag.attachNote(linalgOp->getLoc()) << "payload op";
3317 return diag;
3318 }
3319
3320 FailureOr<StaticContinuousTileSizeSpecification> spec =
3321 computeStaticContinuousTileSizes(linalgOp, getDimension(),
3322 getTargetSize());
3323 if (failed(spec)) {
3324 return emitSilenceableError()
3325 << "failed to compute multi-size tiling sizes";
3326 }
3327
3328 SmallVector<int64_t> chunkSizes;
3329
3330 for (auto &&[tileSize, tripCount] :
3331 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3332 chunkSizes.push_back(tileSize * tripCount);
3333
3334 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3335 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3336 return builder.getI64IntegerAttr(value);
3337 });
3338 };
3339 transformResults.setParams(cast<OpResult>(getTileSizes()),
3340 getI64AttrsFromI64(spec->tileSizes));
3341 transformResults.setParams(cast<OpResult>(getChunkSizes()),
3342 getI64AttrsFromI64(chunkSizes));
3343
3344 return DiagnosedSilenceableFailure::success();
3345 }
3346
3347 builder.setInsertionPoint(linalgOp);
3348
3349 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
3350 unsigned dimension = getDimension();
3351
3352 FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
3353 builder, tileableOp, dimension, targetSize, true);
3354 if (failed(spec)) {
3355 return emitSilenceableError() << "could not generate tile size computation";
3356 }
3357
3358 AffineExpr s0 = builder.getAffineSymbolExpr(0);
3359 AffineExpr s1 = builder.getAffineSymbolExpr(1);
3360 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3361 return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
3362 ofrs);
3363 };
3364
3365 SmallVector<Value> chunkSizes;
3366 Value splitPoint;
3367 for (auto &&[tileSize, tripCount] :
3368 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3369 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3370 chunkSizes.push_back(splitPoint);
3371 }
3372
3373 auto getDefiningOps = [&](ArrayRef<Value> values) {
3374 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3375 return value.getDefiningOp();
3376 });
3377 };
3378
3379 transformResults.set(cast<OpResult>(getTileSizes()),
3380 getDefiningOps(spec->tileSizes));
3381 transformResults.set(cast<OpResult>(getChunkSizes()),
3382 getDefiningOps(chunkSizes));
3383
3384 return DiagnosedSilenceableFailure::success();
3385}
3386
3387LogicalResult transform::ContinuousTileSizesOp::verify() {
3388
3389 if (getTileSizes().getType() != getChunkSizes().getType()) {
3390 return emitOpError() << "expects all results type to be the same";
3391 }
3392
3393 return success();
3394}
3395
3396void transform::ContinuousTileSizesOp::getEffects(
3397 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3398 if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3399 onlyReadsPayload(effects);
3400 else
3401 modifiesPayload(effects);
3402 onlyReadsHandle(getTargetMutable(), effects);
3403 producesHandle(getOperation()->getOpResults(), effects);
3404}
3405
3406 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
3407 Type targetType, Type tileSizes,
3408 Type) {
3409 printer.printFunctionalType(TypeRange{targetType}, TypeRange{tileSizes});
3410}
3411
3412 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
3413 Type &targetType,
3414 Type &tileSizesType,
3415 Type &chunkSizesType) {
3416 FunctionType funcType;
3417 llvm::SMLoc typeLoc = parser.getCurrentLocation();
3418 if (failed(parser.parseType<FunctionType>(funcType)))
3419 return failure();
3420
3421 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3422 parser.emitError(typeLoc) << "expects a trailing functional type with one "
3423 "argument and one result";
3424 }
3425 targetType = funcType.getInput(0);
3426 tileSizesType = chunkSizesType = funcType.getResult(0);
3427
3428 return success();
3429}
3430
3431//===----------------------------------------------------------------------===//
3432// TileUsingForOp
3433//===----------------------------------------------------------------------===//
3434
3435void transform::TileUsingForOp::build(
3436 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3437 Value target, ArrayRef<int64_t> staticTileSizes,
3438 ArrayRef<int64_t> interchange,
3439 std::optional<ArrayRef<bool>> scalableSizes) {
3440 return build(builder, result, loopTypes,
3441 /*target=*/target,
3442 /*mixedTileSizes=*/
3443 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3444 interchange, scalableSizes);
3445}
3446
3447void transform::TileUsingForOp::build(
3448 OpBuilder &builder, OperationState &result, Value target,
3449 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3450 std::optional<ArrayRef<bool>> scalableSizes) {
3451 build(builder, result, target,
3452 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3453 interchange, scalableSizes);
3454}
3455
3456void transform::TileUsingForOp::build(
3457 OpBuilder &builder, OperationState &result, Value target,
3458 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3459 std::optional<ArrayRef<bool>> scalableSizes) {
3460 // Loop types are automatically splat by the callee, setting up one is
3461 // enough.
3462 SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3463 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3464 scalableSizes);
3465}
3466
3467void transform::TileUsingForOp::build(
3468 OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3469 Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3470 ArrayRef<int64_t> interchange,
3471 std::optional<ArrayRef<bool>> scalableSizes) {
3472 SmallVector<int64_t> staticTileSizes;
3473 SmallVector<Value> dynamicTileSizes;
3474 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3475 // Call the default builder which sets up the proper operands segment sizes
3476 // attributes for multiple variadic operands. In the absence of this,
3477 // horrible bugs ensue.
3478 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3479 unsigned numExpectedLoops =
3480 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3481 SmallVector<Type> resultTypes;
3482 resultTypes.reserve(numExpectedLoops);
3483 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3484 "expected one loop type or as many as loops");
3485 if (loopTypes.size() == 1)
3486 resultTypes.append(numExpectedLoops, loopTypes[0]);
3487 else
3488 llvm::append_range(resultTypes, loopTypes);
3489 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3490 if (scalableSizes.has_value())
3491 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3492 build(builder, result, /*tiled_linalg_op=*/target.getType(),
3493 /*loops=*/resultTypes,
3494 /*target=*/target,
3495 /*dynamic_sizes=*/dynamicTileSizes,
3496 /*static_sizes=*/staticTileSizesAttr,
3497 /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3498 /*scalable_sizes=*/expandedScalableSizes);
3499}
3500
3501LogicalResult transform::TileUsingForOp::verify() {
3502 if (getMixedSizes().size() != getScalableSizes().size())
3503 return emitOpError("expected same number of sizes (")
3504 << getMixedSizes().size() << ") and scalable sizes ("
3505 << getScalableSizes().size() << ")";
3506 ArrayRef<int64_t> staticSizes = getStaticSizes();
3507 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3508 if (getLoops().size() != numExpectedLoops)
3509 return emitOpError("expected number of loops to tile (")
3510 << numExpectedLoops << ") to match number of `loops` results ("
3511 << getLoops().size() << ")";
3512 return success();
3513}
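// Editorial note (not part of the original source): a zero entry in the tile
// sizes means "do not tile this dimension" and produces no loop. For example,
// with static sizes [4, 0, 8], numExpectedLoops above is 3 - 1 = 2, so the op
// must declare exactly two `loops` results.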
3514
3515DiagnosedSilenceableFailure
3516transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3517 TransformResults &transformResults,
3518 TransformState &state) {
3519 ArrayRef<int64_t> tileSizes = getStaticSizes();
3520
3521 SmallVector<Operation *> targets =
3522 llvm::to_vector(state.getPayloadOps(getTarget()));
3523 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3524 SmallVector<SmallVector<int64_t>> paramSizes;
3525 dynamicSizeProducers.reserve(getDynamicSizes().size());
3526 paramSizes.reserve(getDynamicSizes().size());
3527 for (Value transformValue : getDynamicSizes()) {
3528 if (isa<ParamType>(transformValue.getType())) {
3529 dynamicSizeProducers.push_back({});
3530 ArrayRef<Attribute> params = state.getParams(transformValue);
3531 paramSizes.push_back(
3532 llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3533 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3534 })));
3535
3536 if (paramSizes.back().size() != targets.size()) {
3537 DiagnosedSilenceableFailure diag =
3538 emitSilenceableError()
3539 << "expected as many parameter values ("
3540 << dynamicSizeProducers.back().size() << ") as target ops ("
3541 << targets.size() << ")";
3542 diag.attachNote(transformValue.getLoc()) << "for this parameter";
3543 return diag;
3544 }
3545
3546 continue;
3547 }
3548 paramSizes.push_back({});
3549 dynamicSizeProducers.push_back(
3550 llvm::to_vector(state.getPayloadOps(transformValue)));
3551
3552 if (dynamicSizeProducers.back().size() != targets.size()) {
3553 DiagnosedSilenceableFailure diag =
3554 emitSilenceableError()
3555 << "expected as many dynamic size-producing operations ("
3556 << dynamicSizeProducers.back().size() << ") as target ops ("
3557 << targets.size() << ")";
3558 diag.attachNote(transformValue.getLoc()) << "for this handle";
3559 return diag;
3560 }
3561
3562 for (Operation *op : dynamicSizeProducers.back()) {
3563 if (op->getNumResults() == 1 &&
3564 isa<IndexType>(op->getResult(0).getType())) {
3565 continue;
3566 }
3567
3568 DiagnosedSilenceableFailure diag =
3569 emitSilenceableError() << "expected sizes to be produced by ops "
3570 "with a single index-type result";
3571 diag.attachNote(op->getLoc()) << "size producer op";
3572 diag.attachNote(transformValue.getLoc()) << "for this handle";
3573 return diag;
3574 }
3575 }
3576
3577 SmallVector<Operation *> tiled;
3578 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3579 loops.resize(getLoops().size());
3580 auto scalableSizes = getScalableSizes();
3581 for (auto [i, op] : llvm::enumerate(targets)) {
3582 auto tilingInterface = dyn_cast<TilingInterface>(op);
3583 if (!tilingInterface) {
3584 DiagnosedSilenceableFailure diag =
3585 emitSilenceableError()
3586 << "only ops implementing TilingInterface are supported";
3587 diag.attachNote(op->getLoc()) << "target op";
3588 return diag;
3589 }
3590 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3591 DiagnosedSilenceableFailure diag =
3592 emitSilenceableError()
3593 << "too many tiles provided, expected at most "
3594 << tilingInterface.getLoopIteratorTypes().size() << " found "
3595 << tileSizes.size();
3596 diag.attachNote(op->getLoc()) << "target op";
3597 return diag;
3598 }
3599
3600 scf::SCFTilingOptions tilingOptions;
3601 if (tileSizes.empty()) {
3602 tilingOptions.setTileSizeComputationFunction(
3603 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3604 return {};
3605 });
3606 } else {
3607 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3608 Operation *) {
3609 SmallVector<OpFoldResult> sizes;
3610 sizes.reserve(tileSizes.size());
3611 unsigned dynamicIdx = 0;
3612
3613 for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3614 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3615 if (scalableSizes[ofrIdx]) {
3616 auto val = arith::ConstantIndexOp::create(
3617 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3618 Value vscale =
3619 vector::VectorScaleOp::create(b, getLoc(), b.getIndexType());
3620 sizes.push_back(
3621 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3622 } else {
3623 sizes.push_back(attr);
3624 }
3625 continue;
3626 }
3627 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3628 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3629 ++dynamicIdx;
3630 assert((dynamicSizes.empty() ^ params.empty()) &&
3631 "expected either dynamic sizes or parameters");
3632 if (!params.empty()) {
3633 sizes.push_back(b.getIndexAttr(params[index]));
3634 } else {
3635 sizes.push_back(dynamicSizes[index]->getResult(0));
3636 }
3637 }
3638 return sizes;
3639 });
3640 }
3641
3642 tilingOptions.setInterchange(getInterchange());
3643 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3644 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3645 if (failed(maybeTilingResult))
3646 return DiagnosedSilenceableFailure::definiteFailure();
3647
3648 rewriter.replaceOp(op, maybeTilingResult->replacements);
3649
3650 tiled.append(maybeTilingResult->tiledOps);
3651 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3652 loops[en2.index()].push_back(en2.value());
3653 }
3654
3655 transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3656 for (const auto &en : llvm::enumerate(loops))
3657 transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3658
3659 return DiagnosedSilenceableFailure::success();
3660}
3661
3662SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3663 ValueRange dynamic = getDynamicSizes();
3664 ArrayRef<int64_t> tileSizes = getStaticSizes();
3665 SmallVector<OpFoldResult> results;
3666 results.reserve(tileSizes.size());
3667 unsigned dynamicPos = 0;
3668 Builder builder(getContext());
3669 for (int64_t size : tileSizes) {
3670 if (size == ShapedType::kDynamic) {
3671 results.push_back(dynamic[dynamicPos++]);
3672 } else {
3673 results.push_back(builder.getIndexAttr(size));
3674 }
3675 }
3676 return results;
3677}
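// Editorial note (not part of the original source): getMixedSizes() merges the
// static and dynamic size lists. For example, with static_sizes =
// [4, kDynamic, 0] and one dynamic size operand %sz, it returns
// {IndexAttr(4), %sz, IndexAttr(0)}; dynamic entries are consumed from the
// operand list in order.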
3678
3679void transform::TileUsingForOp::getEffects(
3680 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3681 consumesHandle(getTargetMutable(), effects);
3682 onlyReadsHandle(getDynamicSizesMutable(), effects);
3683 producesHandle(getOperation()->getOpResults(), effects);
3684 modifiesPayload(effects);
3685}
3686
3687//===----------------------------------------------------------------------===//
3688// TileUsingForallOp
3689//===----------------------------------------------------------------------===//
3690
3691void transform::TileUsingForallOp::build(OpBuilder &builder,
3692 OperationState &result, Value target,
3693 ArrayRef<int64_t> staticTileSizes,
3694 transform::TileSizesSpec,
3695 ArrayAttr mapping) {
3696 return build(builder, result,
3697 /*target=*/target,
3698 /*mixedTileSizes=*/
3699 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3700 /*_=*/TileSizesSpec(),
3701 /*mapping=*/mapping);
3702}
3703
3704void transform::TileUsingForallOp::build(OpBuilder &builder,
3705 OperationState &result, Value target,
3706 ArrayRef<OpFoldResult> mixedTileSizes,
3707 transform::TileSizesSpec,
3708 ArrayAttr mapping) {
3709 SmallVector<int64_t> staticTileSizes;
3710 SmallVector<Value> dynamicTileSizes;
3711 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3712 // Call the default builder which sets up the proper operands segment sizes
3713 // attributes for multiple variadic operands. In the absence of this,
3714 // horrible bugs ensue.
3715 MLIRContext *ctx = builder.getContext();
3716 auto operationType = transform::AnyOpType::get(ctx);
3717 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3718 build(builder, result,
3719 /*resultTypes=*/TypeRange{operationType, operationType},
3720 /*target=*/target,
3721 /*num_threads=*/ValueRange{},
3722 /*tile_sizes=*/dynamicTileSizes,
3723 /*packed_num_threads=*/Value(),
3724 /*packed_tile_sizes=*/Value(),
3725 /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3726 /*static_tile_sizes=*/staticTileSizesAttr,
3727 /*mapping=*/mapping);
3728}
3729
3730void transform::TileUsingForallOp::build(OpBuilder &builder,
3731 OperationState &result, Value target,
3732 ArrayRef<int64_t> staticNumThreads,
3733 transform::NumThreadsSpec,
3734 ArrayAttr mapping) {
3735 return build(builder, result, target,
3736 getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3737 NumThreadsSpec(), mapping);
3738}
3739
3740void transform::TileUsingForallOp::build(OpBuilder &builder,
3741 OperationState &result, Value target,
3742 ArrayRef<OpFoldResult> mixedNumThreads,
3743 transform::NumThreadsSpec,
3744 ArrayAttr mapping) {
3745 SmallVector<int64_t> staticNumThreads;
3746 SmallVector<Value> dynamicNumThreads;
3747 dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3748 staticNumThreads);
3749 // Call the default builder which sets up the proper operands segment sizes
3750 // attributes for multiple variadic operands. In the absence of this,
3751 // horrible bugs ensue.
3752 MLIRContext *ctx = builder.getContext();
3753 auto operationType = transform::AnyOpType::get(ctx);
3754 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3755 build(builder, result,
3756 /*resultTypes=*/TypeRange{operationType, operationType},
3757 /*target=*/target,
3758 /*num_threads=*/dynamicNumThreads,
3759 /*tile_sizes=*/ValueRange{},
3760 /*packed_num_threads=*/Value(),
3761 /*packed_tile_sizes=*/Value(),
3762 /*static_num_threads=*/staticNumThreadsAttr,
3763 /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3764 /*mapping=*/mapping);
3765}
3766
3767/// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3768/// normalized upper bound.
3769 static SmallVector<OpFoldResult>
3770 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3771 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3772 ArrayRef<OpFoldResult> steps) {
3773 AffineExpr s0, s1, s2;
3774 bindSymbols(rewriter.getContext(), s0, s1, s2);
3775 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3776 SmallVector<OpFoldResult> normalizedUbs;
3777 for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3778 OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3779 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3780 normalizedUbs.push_back(normalizedUb);
3781 }
3782 return normalizedUbs;
3783}
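// Editorial note (not part of the original source): for lb = 4, ub = 20 and
// step = 3, the normalized upper bound computed above is
// ceilDiv(20 - 4, 3) = 6, i.e. the loop is rewritten to run for 6 iterations
// starting at 0 with step 1.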
3784
3785/// When a loop is normalized, the uses of the induction variable within the
3786/// loop need to be replaced with `original_lb + old_iv * original_step`.
3787 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3788 Location loc, ValueRange ivs,
3789 ArrayRef<OpFoldResult> lbs,
3790 ArrayRef<OpFoldResult> steps) {
3791 AffineExpr s0, s1;
3792 AffineExpr d0;
3793 bindSymbols(rewriter.getContext(), s0, s1);
3794 bindDims(rewriter.getContext(), d0);
3795 AffineExpr denormExpr = s0 + d0 * s1;
3796 SmallVector<Value> denormalizedIvs;
3797
3798 for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3799 OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3800 rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3801 denormalizedIvs.push_back(
3802 getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3803 }
3804 return denormalizedIvs;
3805}
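// Editorial note (not part of the original source): continuing the example
// above (lb = 4, step = 3), a normalized induction variable iv in [0, 6) is
// mapped back to 4 + iv * 3, i.e. the original values 4, 7, 10, 13, 16, 19.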
3806
3807/// Given a `scf.forall` loop return a loop op with the loop bounds
3808/// normalized.
3809/// TODO: Replace this with a general utility to normalize `scf.forall`.
3810/// At the time of writing, this wasn't done since adding this to the `scf`
3811/// dialect would disallow the use of `affine.apply` operations due
3812/// to cyclic dependencies. To avoid churn in lit tests
3813/// with the change this was added with, defer that to a follow up.
3814static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3815 scf::ForallOp loop) {
3816 SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3817 SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3818 SmallVector<OpFoldResult> steps = loop.getMixedStep();
3819
3820 if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3821 return loop;
3822 }
3823
3824 Location loc = loop.getLoc();
3825 SmallVector<OpFoldResult> normalizedUbs =
3826 normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3827 SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3828 rewriter.getIndexAttr(0));
3829 SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3830 rewriter.getIndexAttr(1));
3831
3832 auto normalizedForallOp = scf::ForallOp::create(
3833 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3834 loop.getOutputs(), loop.getMapping(),
3835 [](OpBuilder &, Location, ValueRange) {});
3836
3837 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3838 OpBuilder::InsertionGuard g(rewriter);
3839 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3840 rewriter.setInsertionPointToStart(normalizedLoopBlock);
3841
3842 SmallVector<Value> argValues =
3843 denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3844 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3845 normalizedForallOp.getRegionIterArgs().end());
3846 Block *origLoopBlock = loop.getBody();
3847 rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3848
3849 rewriter.replaceOp(loop, normalizedForallOp);
3850 return normalizedForallOp;
3851}
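// Editorial sketch (not part of the original source) of the effect of
// normalizeForallLoopOp, using the same bounds as above:
//
//   scf.forall (%i) = (4) to (20) step (3) { ... uses of %i ... }
//
// becomes roughly
//
//   scf.forall (%j) in (6) { ... uses of (4 + %j * 3) ... }
//
// with the original body merged into the new loop after substituting the
// denormalized induction variable.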
3852
3853 static DiagnosedSilenceableFailure tileToForallOpImpl(
3854 RewriterBase &rewriter, transform::TransformState &state,
3855 TransformOpInterface transformOp, Operation *target,
3856 ArrayRef<OpFoldResult> mixedNumThreads,
3857 ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3858 scf::SCFTilingResult &tilingResult) {
3859 // Transform all targets one by one.
3860 auto tileableOp = dyn_cast<TilingInterface>(target);
3861 if (!tileableOp) {
3863 transformOp.emitSilenceableError()
3864 << "only TilingInterface ops are supported";
3865 diag.attachNote(target->getLoc()) << "target op";
3866 return diag;
3867 }
3868 rewriter.setInsertionPoint(tileableOp);
3869 scf::SCFTilingOptions options;
3870 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3871 if (!mixedNumThreads.empty()) {
3872 options.setNumThreads(mixedNumThreads);
3873 } else {
3874 options.setTileSizes(mixedTileSizes);
3875 }
3876 if (mapping) {
3877 options.setMapping(mapping.value().getValue());
3878 }
3879 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3880 scf::tileUsingSCF(rewriter, tileableOp, options);
3881
3882 if (failed(maybeTilingResult))
3883 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3884
3885 rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3886
3887 tilingResult = *maybeTilingResult;
3888
3889 if (mixedNumThreads.empty()) {
3890 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3891 OpBuilder::InsertionGuard g(rewriter);
3892 rewriter.setInsertionPoint(generatedForallOp);
3893 scf::ForallOp normalizedForallOp =
3894 normalizeForallLoopOp(rewriter, generatedForallOp);
3895 tilingResult.loops.front() = normalizedForallOp;
3896 }
3897
3898 return DiagnosedSilenceableFailure::success();
3899}
3900
3901 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3902 transform::TransformRewriter &rewriter,
3903 transform::TransformResults &transformResults,
3904 transform::TransformState &state) {
3905 auto transformOp = cast<TransformOpInterface>(getOperation());
3906
3907 // Result payload ops.
3908 SmallVector<Operation *> tileOps;
3909 SmallVector<Operation *> tiledOps;
3910
3911 // Unpack handles.
3912 SmallVector<OpFoldResult> mixedNumThreads;
3913 DiagnosedSilenceableFailure status =
3914 getPackedNumThreads()
3915 ? unpackSingleIndexResultPayloadOperations(
3916 state, transformOp, mixedNumThreads, getPackedNumThreads())
3917 : unpackSingleIndexResultPayloadOperations(
3918 state, transformOp, mixedNumThreads, getMixedNumThreads());
3919 if (!status.succeeded())
3920 return status;
3921 SmallVector<OpFoldResult> mixedTileSizes;
3922 status = getPackedTileSizes()
3923 ? unpackSingleIndexResultPayloadOperations(
3924 state, transformOp, mixedTileSizes, getPackedTileSizes())
3925 : unpackSingleIndexResultPayloadOperations(
3926 state, transformOp, mixedTileSizes, getMixedTileSizes());
3927 if (!status.succeeded())
3928 return status;
3929
3930 for (Operation *target : state.getPayloadOps(getTarget())) {
3931 scf::SCFTilingResult tilingResult;
3932 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3933 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3934 getMapping(), tilingResult);
3935 if (!diag.succeeded())
3936 return diag;
3937 tileOps.push_back(tilingResult.loops.front());
3938 tiledOps.append(tilingResult.tiledOps);
3939 }
3940
3941 transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3942 transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3943
3944 return DiagnosedSilenceableFailure::success();
3945}
3946
3947void transform::TileUsingForallOp::getEffects(
3948 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3949 consumesHandle(getTargetMutable(), effects);
3950 onlyReadsHandle(getTileSizesMutable(), effects);
3951 onlyReadsHandle(getNumThreadsMutable(), effects);
3952 onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3953 onlyReadsHandle(getPackedTileSizesMutable(), effects);
3954 producesHandle(getOperation()->getOpResults(), effects);
3955 modifiesPayload(effects);
3956}
3957
3958SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3959 Builder b(getContext());
3960 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3961}
3962
3963SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3964 Builder b(getContext());
3965 return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3966}
3967
3968LogicalResult TileUsingForallOp::verify() {
3969 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3970 static_cast<int>(getPackedNumThreads() != Value());
3971 if (numThreadsSpec > 1)
3972 return emitOpError(
3973 "num_threads and packed_num_threads are mutually exclusive");
3974 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3975 static_cast<int>(getPackedTileSizes() != Value());
3976 if (tileSizesSpec > 1)
3977 return emitOpError(
3978 "tile_sizes and packed_tile_sizes are mutually exclusive");
3979 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3980 return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3981 "must be specified");
3982 return success();
3983}
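// Editorial note (not part of the original source): examples of the
// constraints checked above:
//   tile_sizes = [8, 16]                          -> valid
//   packed_num_threads = %h                       -> valid
//   tile_sizes = [8] and packed_tile_sizes = %h   -> rejected (exclusive)
//   neither num_threads nor tile_sizes provided   -> rejected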
3984
3985//===----------------------------------------------------------------------===//
3986// VectorizeChildrenAndApplyPatternsOp
3987//===----------------------------------------------------------------------===//
3988
3989void transform::VectorizeChildrenAndApplyPatternsOp::build(
3990 OpBuilder &builder, OperationState &result, Value target,
3991 bool foldTypeExtensionsIntoContract, bool vectorizePadding,
3992 bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3993 result.addOperands(target);
3994 if (foldTypeExtensionsIntoContract) {
3995 result.addAttribute(
3996 VectorizeChildrenAndApplyPatternsOp::
3997 getFoldTypeExtensionsIntoContractAttrName(result.name),
3998 builder.getUnitAttr());
3999 }
4000 if (vectorizePadding) {
4001 result.addAttribute(
4002 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4003 result.name),
4004 builder.getUnitAttr());
4005 }
4006 if (vectorizeExtract) {
4007 result.addAttribute(
4008 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4009 result.name),
4010 builder.getUnitAttr());
4011 }
4012 if (flatten1DDepthwiseConv) {
4013 result.addAttribute(
4014 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4015 result.name),
4016 builder.getUnitAttr());
4017 }
4018 result.addTypes(transform::AnyOpType::get(builder.getContext()));
4019}
4020
4021namespace {
4022/// This is a helper only to call vectorize via a pattern inside of
4023/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
4024struct VectorizationPattern : public RewritePattern {
4025 explicit VectorizationPattern(MLIRContext *context,
4026 bool vectorizeExtract = false,
4027 bool flattenConv = false)
4028 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
4029 vectorizeNDExtract(vectorizeExtract),
4030 flatten1DDepthwiseConv(flattenConv) {}
4031 LogicalResult matchAndRewrite(Operation *op,
4032 PatternRewriter &rewriter) const override {
4033 if (!linalg::hasVectorizationImpl(op))
4034 return rewriter.notifyMatchFailure(op,
4035 "Unsupported Op, cannot vectorize");
4036 FailureOr<VectorizationResult> vectorResults =
4037 vectorize(rewriter, op, /*inputVectorSizes=*/{},
4038 /*inputScalableVecDims=*/{}, vectorizeNDExtract,
4039 flatten1DDepthwiseConv);
4040 if (failed(vectorResults))
4041 return failure();
4042 rewriter.replaceOp(op, vectorResults->replacements);
4043 return success();
4044 }
4045
4046private:
4047 /// Controls whether to vectorize `tensor.extract` when the input tensor is
4048 /// rank >= 2.
4049 bool vectorizeNDExtract = false;
4050 /// Controls whether to "flatten" the channel dimension when vectorising 1D
4051 /// depthwise convolutions. This should lead to better vectorization for
4052 /// tensors with a low number of channel dimensions.
4053 bool flatten1DDepthwiseConv = false;
4054};
4055} // namespace
4056
4057DiagnosedSilenceableFailure
4058transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4059 transform::TransformRewriter &rewriter, Operation *target,
4060 transform::ApplyToEachResultList &results,
4061 transform::TransformState &state) {
4062 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4063 auto diag = this->emitOpError("requires isolated-from-above targets");
4064 diag.attachNote(target->getLoc()) << "non-isolated target";
4065 return DiagnosedSilenceableFailure::definiteFailure();
4066 }
4067
4068 MLIRContext *ctx = getContext();
4069 RewritePatternSet patterns(ctx);
4070 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4071 getFlatten_1dDepthwiseConv());
4072
4073 if (!getDisableTransferPermutationMapLoweringPatterns())
4074 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
4075
4076 if (!getDisableMultiReductionToContractPatterns())
4077 vector::populateVectorReductionToContractPatterns(patterns);
4078
4079 vector::populateSinkVectorOpsPatterns(patterns);
4080
4081 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
4082 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4083 /*benefit=*/2);
4084 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4085 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4086 tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
4087
4088 patterns.add<CopyVectorizationPattern>(ctx);
4089
4090 if (getFoldTypeExtensionsIntoContract())
4091 vector::populateFoldArithExtensionPatterns(patterns);
4092
4093 if (getVectorizePadding()) {
4094 linalg::populatePadOpVectorizationPatterns(patterns);
4095 // This creates an alternative path for lowering tensor.pad - by
4096 // decomposing it into e.g. linalg.fill.
4097 linalg::populateDecomposePadPatterns(patterns);
4098 }
4099 vector::populateVectorStepLoweringPatterns(patterns);
4100
4101 TrackingListener listener(state, *this);
4102 if (failed(
4103 applyPatternsGreedily(target, std::move(patterns),
4104 GreedyRewriteConfig().setListener(&listener))))
4105 return emitDefaultDefiniteFailure(target);
4106
4107 results.push_back(target);
4108 return DiagnosedSilenceableFailure::success();
4109}
4110
4111//===----------------------------------------------------------------------===//
4112// VectorizeOp
4113//===----------------------------------------------------------------------===//
4114
4115DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4116 transform::TransformRewriter &rewriter,
4117 mlir::transform::TransformResults &transformResults,
4118 mlir::transform::TransformState &state) {
4119 auto targets = state.getPayloadOps(getTarget());
4120 if (std::empty(targets))
4121 return DiagnosedSilenceableFailure::success();
4122 auto transformOp = cast<TransformOpInterface>(getOperation());
4123 SmallVector<int64_t> vectorSizes;
4124 DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
4125 state, transformOp, getMixedVectorSizes(), vectorSizes);
4126 if (!status.succeeded())
4127 return status;
4128
4129 // TODO: Check that the correct number of vectorSizes was provided.
4130 for (Operation *target : targets) {
4131 if (!linalg::hasVectorizationImpl(target)) {
4132 return mlir::emitSilenceableFailure(target->getLoc())
4133 << "Unsupported Op, cannot vectorize";
4134 }
4135 FailureOr<VectorizationResult> vectorResults =
4136 linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
4137 getVectorizeNdExtract().value_or(false),
4138 /*flatten1DDepthwiseConv=*/false,
4139 getAssumeDynamicDimsMatchVecSizes().value_or(false),
4140 getCreateNamedContraction().value_or(false));
4141 if (failed(vectorResults)) {
4142 return mlir::emitSilenceableFailure(target->getLoc())
4143 << "Attempted to vectorize, but failed";
4144 }
4145 rewriter.replaceOp(target, vectorResults->replacements);
4146 }
4147
4148 return DiagnosedSilenceableFailure::success();
4149}
4150
4151void transform::VectorizeOp::getEffects(
4152 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4153 consumesHandle(getTargetMutable(), effects);
4154 onlyReadsHandle(getVectorSizesMutable(), effects);
4155 modifiesPayload(effects);
4156}
4157
4158SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4159 OpBuilder b(getContext());
4160 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
4161}
4162
4163LogicalResult transform::VectorizeOp::verify() {
4164 if (getStaticVectorSizes().size() != getScalableSizes().size())
4165 return emitOpError("expected same number of vector sizes (")
4166 << getStaticVectorSizes().size() << ") and scalable sizes ("
4167 << getScalableSizes().size() << ")";
4168 return success();
4169}
4170
4171//===----------------------------------------------------------------------===//
4172// HoistRedundantVectorTransfersOp
4173//===----------------------------------------------------------------------===//
4174
4175DiagnosedSilenceableFailure
4176transform::HoistRedundantVectorTransfersOp::applyToOne(
4177 transform::TransformRewriter &rewriter, func::FuncOp target,
4178 transform::ApplyToEachResultList &results,
4179 transform::TransformState &state) {
4180 // WARNING: This hoisting does not model parallelism and is generally
4181 // incorrect when used on distributed loops with memref semantics!
4182 // TODO: obsolete and should be retired.
4183 linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
4184 results.push_back(target);
4185 return DiagnosedSilenceableFailure::success();
4186}
4187
4188//===----------------------------------------------------------------------===//
4189// HoistRedundantVectorBroadcastsOp
4190//===----------------------------------------------------------------------===//
4191
4192DiagnosedSilenceableFailure
4193transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4194 transform::TransformRewriter &rewriter, mlir::Operation *target,
4195 transform::ApplyToEachResultList &results,
4196 transform::TransformState &state) {
4197 rewriter.setInsertionPoint(target);
4198 linalg::hoistRedundantVectorBroadcasts(rewriter, target);
4199 results.push_back(target);
4200 return DiagnosedSilenceableFailure::success();
4201}
4202
4203//===----------------------------------------------------------------------===//
4204// ConvertConv2DToImg2ColOp.
4205//===----------------------------------------------------------------------===//
4206
4207DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4208 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4209 transform::ApplyToEachResultList &results,
4210 transform::TransformState &state) {
4211 rewriter.setInsertionPoint(target);
4212 auto maybeTransformed =
4213 TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
4214 target)
4215 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4216 return rewriteInIm2Col(rewriter, op);
4217 })
4218 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4219 return rewriteInIm2Col(rewriter, op);
4220 })
4221 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4222 return rewriteInIm2Col(rewriter, op);
4223 })
4224 .Case([&](linalg::Conv2DNchwFchwOp op) {
4225 return rewriteInIm2Col(rewriter, op);
4226 })
4227 .Default([&](Operation *op) {
4228 return rewriter.notifyMatchFailure(op, "not supported");
4229 });
4230 if (failed(maybeTransformed))
4231 return emitDefaultSilenceableFailure(target);
4232 // Handle to the operation producing the img2col tensor.
4233 results.push_back(maybeTransformed->first);
4234 // Handle to the operation that replaces the original convolution.
4235 results.push_back(maybeTransformed->second);
4236 return DiagnosedSilenceableFailure::success();
4237}
4238
4239//===----------------------------------------------------------------------===//
4240// FlattenElementwiseLinalgOp.
4241//===----------------------------------------------------------------------===//
4242
4243DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4244 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4245 transform::ApplyToEachResultList &results,
4246 transform::TransformState &state) {
4247 rewriter.setInsertionPoint(target);
4248 if (!isElementwise(target))
4249 return mlir::emitSilenceableFailure(target->getLoc())
4250 << "only elementwise flattening is supported";
4251
4252 // If rank <= 1, do nothing
4253 if (target.getNumLoops() <= 1) {
4254 results.push_back(target);
4255 return DiagnosedSilenceableFailure::success();
4256 }
4257
4258 // Attempt to flatten all dims to one.
4259 ReassociationIndices reassociation(target.getNumLoops());
4260 std::iota(reassociation.begin(), reassociation.end(), 0);
4261 auto maybeFlattened =
4262 collapseOpIterationDims(target, reassociation, rewriter);
4263 if (failed(maybeFlattened))
4264 return mlir::emitSilenceableFailure(target->getLoc())
4265 << "attempted to flatten, but failed";
4266 results.push_back(maybeFlattened->collapsedOp);
4267 rewriter.replaceOp(target, maybeFlattened->results);
4268 return DiagnosedSilenceableFailure::success();
4269}
4270
4271//===----------------------------------------------------------------------===//
4272// TransposeConv2DOp
4273//===----------------------------------------------------------------------===//
4274
4275DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4276 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4277 transform::ApplyToEachResultList &results,
4278 transform::TransformState &state) {
4279 rewriter.setInsertionPoint(target);
4280 auto maybeTransformed =
4281 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4282 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4283 return transposeConv2D(rewriter, op);
4284 })
4285 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4286 return transposeConv2D(rewriter, op);
4287 })
4288 .Default([&](Operation *op) {
4289 return rewriter.notifyMatchFailure(op, "not supported");
4290 });
4291 if (failed(maybeTransformed))
4292 return emitDefaultSilenceableFailure(target);
4293 // Handle to the new Conv2D operation with transposed filters
4294 results.push_back(*maybeTransformed);
4295 return DiagnosedSilenceableFailure::success();
4296}
4297
4298//===----------------------------------------------------------------------===//
4299// TransposeMatmulOp
4300//===----------------------------------------------------------------------===//
4301
4302DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4303 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4304 transform::ApplyToEachResultList &results,
4305 transform::TransformState &state) {
4306 rewriter.setInsertionPoint(target);
4307 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4308 auto maybeTransformed =
4309 TypeSwitch<Operation *, FailureOr<Operation *>>(target)
4310 .Case([&](linalg::MatmulOp op) {
4311 return transposeMatmul(rewriter, op, transposeLHS);
4312 })
4313 .Case([&](linalg::BatchMatmulOp op) {
4314 return transposeBatchMatmul(rewriter, op, transposeLHS);
4315 })
4316 .Default([&](Operation *op) { return failure(); });
4317 if (failed(maybeTransformed))
4318 return emitSilenceableFailure(target->getLoc()) << "not supported";
4319 // Handle to the new Matmul operation with transposed filters
4320 results.push_back(*maybeTransformed);
4321 return DiagnosedSilenceableFailure::success();
4322}
4323
4324//===----------------------------------------------------------------------===//
4325// InsertSliceToCopyOp
4326//===----------------------------------------------------------------------===//
4327template <typename OpTy>
4328static DiagnosedSilenceableFailure
4329 doit(RewriterBase &rewriter, OpTy target,
4330 transform::ApplyToEachResultList &results,
4331 transform::TransformState &state) {
4332 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4333 tensor::ParallelInsertSliceOp>() &&
4334 "wrong op type");
4335
4336 if (auto copySource =
4337 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4338 results.push_back(copySource);
4339 return DiagnosedSilenceableFailure::success();
4340 }
4341
4342 // If we are inside a `ParallelCombiningOp` region, temporarily set the
4343 // insertion point outside: only ops implementing ParallelCombiningOpInterface
4344 // are allowed in there.
4345 if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4346 rewriter.setInsertionPoint(target->getParentOp());
4347
4348 Value extracted = tensor::ExtractSliceOp::create(
4349 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4350 target.getMixedSizes(), target.getMixedStrides());
4351 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4352 target.getSource(), extracted)
4353 .getResult(0);
4354 // Reset the insertion point.
4355 rewriter.setInsertionPoint(target);
4356 rewriter.replaceOpWithNewOp<OpTy>(
4357 target, copied, target.getDest(), target.getMixedOffsets(),
4358 target.getMixedSizes(), target.getMixedStrides());
4359
4360 results.push_back(copied.getDefiningOp());
4361 return DiagnosedSilenceableFailure::success();
4362}
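// Editorial sketch (not part of the original source) of the rewrite performed
// by doit() above, with types and slice parameters elided:
//
//   %r = tensor.insert_slice %src into %dest[...][...][...]
//
// becomes
//
//   %e = tensor.extract_slice %dest[...][...][...]
//   %c = linalg.copy ins(%src : ...) outs(%e : ...) -> ...
//   %r = tensor.insert_slice %c into %dest[...][...][...]
//
// so the returned handle points at a linalg.copy, unless the inserted source
// was already produced by one.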
4363
4364DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4365 transform::TransformRewriter &rewriter, Operation *targetOp,
4366 transform::ApplyToEachResultList &results,
4367 transform::TransformState &state) {
4368
4369 rewriter.setInsertionPoint(targetOp);
4370 if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4371 return doit(rewriter, target, results, state);
4372 if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4373 return doit(rewriter, target, results, state);
4374
4375 DiagnosedSilenceableFailure diag =
4376 emitSilenceableError()
4377 << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4378 diag.attachNote(targetOp->getLoc()) << "target op";
4379 return diag;
4380}
4381
4382//===----------------------------------------------------------------------===//
4383// MapCopyToThreadsOp
4384//===----------------------------------------------------------------------===//
4385
4386DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4387 transform::TransformRewriter &rewriter, Operation *target,
4388 transform::ApplyToEachResultList &results,
4389 transform::TransformState &state) {
4390 // Check if the op is supported.
4391 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4392 DiagnosedSilenceableFailure diag =
4393 emitSilenceableError()
4394 << "only linalg.copy and tensor.pad target ops are supported";
4395 diag.attachNote(target->getLoc()) << "target op";
4396 return diag;
4397 }
4398 assert(target->getNumResults() == 1 && "expected single result");
4399 auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4400 if (!resultShapedType.hasStaticShape()) {
4401 DiagnosedSilenceableFailure diag =
4402 emitSilenceableError()
4403 << "only statically sized ops of rank <= 3 are supported";
4404 diag.attachNote(target->getLoc()) << "target op";
4405 return diag;
4406 }
4407
4408 // Conservatively set the minimum viable desired bitwidth alignment.
4409 int64_t desiredBitAlignment = getDesiredBitAlignment();
4410 int64_t eltBitwidth =
4411 resultShapedType.getElementType().getIntOrFloatBitWidth();
4412 if (desiredBitAlignment % eltBitwidth != 0) {
4413 desiredBitAlignment = eltBitwidth;
4414 }
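// Editorial note (not part of the original source): e.g. a desired alignment
// of 128 bits with f32 elements (32 bits) is kept since 128 % 32 == 0, while
// a desired alignment of 32 bits with i64 elements does not divide evenly and
// falls back to the element bitwidth of 64.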
4415
4416 gpu::CopyMappingInfo mapping(
4417 /*ctx=*/getContext(),
4418 /*totalNumThreads=*/getTotalNumThreads(),
4419 /*alignment=*/desiredBitAlignment,
4420 /*sizes=*/resultShapedType.getShape(),
4421 /*favorPredication=*/false,
4422 /*elementalBitwidth=*/
4423 resultShapedType.getElementType().getIntOrFloatBitWidth());
4424 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4425 DiagnosedSilenceableFailure diag =
4426 emitSilenceableError()
4427 << "too few threads to map copy op to threads on the most minor "
4428 "dimension, given alignment and vector size constraints, try "
4429 "smaller tile size of mapping to more threads";
4430 diag.attachNote(target->getLoc()) << "target op";
4431 return diag;
4432 }
4433
4434 // OpBuilder only used to compute attributes.
4435 OpBuilder b(getContext());
4436 scf::SCFTilingResult tilingResult;
4437 DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4438 /*rewriter=*/rewriter,
4439 /*state=*/state,
4440 /*transformOp=*/*this,
4441 /*target=*/target,
4442 /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4443 /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4444 /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4445 /*tilingResult=*/tilingResult);
4446 if (!diag.succeeded())
4447 return diag;
4448
4449 results.push_back(tilingResult.loops.front());
4450 for (auto op : tilingResult.tiledOps)
4451 results.push_back(op);
4452 return DiagnosedSilenceableFailure::success();
4453}
4454
4455//===----------------------------------------------------------------------===//
4456// WinogradConv2DOp
4457//===----------------------------------------------------------------------===//
4458
4459DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4460 transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4461 transform::ApplyToEachResultList &results,
4462 transform::TransformState &state) {
4463 rewriter.setInsertionPoint(target);
4464 FailureOr<Operation *> maybeTransformed = failure();
4465 bool supported = TypeSwitch<Operation *, bool>(target)
4466 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4467 maybeTransformed =
4468 winogradConv2D(rewriter, op, getFmr());
4469 return true;
4470 })
4471 .Default([&](Operation *op) { return false; });
4472
4473 if (!supported) {
4474 return emitSilenceableError()
4475 << "this operation is not supported to convert to Winograd Conv2D";
4476 }
4477
4478 if (failed(maybeTransformed)) {
4479 return emitSilenceableError() << "apply Winograd Conv2D failed";
4480 }
4481
4482 results.push_back(*maybeTransformed);
4483 return DiagnosedSilenceableFailure::success();
4484}
4485
4486DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4487 transform::TransformRewriter &rewriter, Operation *target,
4488 transform::ApplyToEachResultList &results,
4489 transform::TransformState &state) {
4490 rewriter.setInsertionPoint(target);
4491 FailureOr<Operation *> maybeTransformed = failure();
4492 bool supported =
4493 TypeSwitch<Operation *, bool>(target)
4494 .Case([&](linalg::WinogradFilterTransformOp op) {
4495 maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4496 return true;
4497 })
4498 .Case([&](linalg::WinogradInputTransformOp op) {
4499 maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4500 return true;
4501 })
4502 .Case([&](linalg::WinogradOutputTransformOp op) {
4503 maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4504 return true;
4505 })
4506 .Default(false);
4507
4508 if (!supported) {
4509 DiagnosedSilenceableFailure diag =
4510 emitSilenceableError()
4511 << "this operation is not supported to decompose into other operations";
4512 diag.attachNote(target->getLoc()) << "target op";
4513 return diag;
4514 }
4515
4516 if (failed(maybeTransformed)) {
4517 DiagnosedSilenceableFailure diag =
4518 emitSilenceableError() << "decompose Winograd operations failed";
4519 diag.attachNote(target->getLoc()) << "target op";
4520 return diag;
4521 }
4522
4523 results.push_back(*maybeTransformed);
4524 return DiagnosedSilenceableFailure::success();
4525}
4526
4527#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4528
4529#define GET_OP_CLASSES
4530#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
Definition AffineExpr.h:68
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
UnitAttr getUnitAttr()
Definition Builders.cpp:98
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition Builders.cpp:306
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
Definition Builders.h:337
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:316
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getOpResult(unsigned idx)
Definition Operation.h:421
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
void setOperand(unsigned idx, Value value)
Definition Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition Operation.h:263
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
result_range getOpResults()
Definition Operation.h:420
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
user_range getUsers() const
Definition Value.h:218
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition Tiling.cpp:857
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition Transforms.h:489
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:215
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition Tiling.cpp:236
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition Tiling.cpp:156
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition Tiling.cpp:106
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition Hoisting.cpp:89
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition Hoisting.cpp:198
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition Split.cpp:67
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition Tiling.cpp:262
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition SCF.cpp:744
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:57
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:285
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition Transforms.h:521
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition Transforms.h:409
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition Transforms.h:456
LinalgPromotionOptions & setAlignment(unsigned align)
Definition Transforms.h:433
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition Transforms.h:440
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition Transforms.h:398
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition Transforms.h:420
LinalgPromotionOptions & setUseOriginalSubviewSize(bool originalSize)
Definition Transforms.h:427
LinalgPromotionOptions & setUseAlloca(bool use)
Definition Transforms.h:446
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition Transforms.h:466