42 #include "llvm/ADT/STLExtras.h"
43 #include "llvm/ADT/ScopeExit.h"
44 #include "llvm/ADT/TypeSwitch.h"
45 #include "llvm/Support/DebugLog.h"
46 #include "llvm/Support/LogicalResult.h"
47 #include <type_traits>
53 #define DEBUG_TYPE "linalg-transforms"
60 template <
typename PatternTy,
typename... Args>
63 using OpTy =
typename llvm::function_traits<
64 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
65 auto op = dyn_cast<OpTy>(operation);
70 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
75 auto result = pattern.returningMatchAndRewrite(op, rewriter);
78 return cast<LinalgOp>(result->getOperation());
88 if (
auto attr = dyn_cast<Attribute>(ofr)) {
89 if (!isa<IntegerAttr>(attr))
90 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
91 result.push_back(ofr);
95 Value transformValue = cast<Value>(ofr);
96 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
98 if (params.size() != 1)
99 return transformOp.emitDefiniteFailure()
100 <<
"requires exactly one parameter associated";
101 result.push_back(params[0]);
105 auto payloadOps = state.getPayloadOps(transformValue);
106 if (!llvm::hasSingleElement(payloadOps)) {
108 transformOp.emitSilenceableError()
109 <<
"handle must be mapped to exactly one payload op";
111 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
118 transformOp.emitSilenceableError()
119 <<
"payload op must have exactly 1 index result";
139 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
141 for (
auto param : params) {
142 if (!isa<IntegerAttr>(param))
143 return transformOp.emitDefiniteFailure()
144 <<
"expected the parameter to be associated with an integer "
146 result.push_back(param);
151 for (
Operation *op : state.getPayloadOps(packedHandle)) {
152 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
154 transformOp.emitSilenceableError()
155 <<
"payload op must have exactly 1 index result";
156 diag.attachNote(op->getLoc())
157 <<
"has " << op->getNumResults() <<
" results";
160 result.push_back(op->getResult(0));
174 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
175 reified.push_back(cast<IntegerAttr>(attr).getInt());
177 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
179 if (params.size() != 1)
180 return transformOp.emitSilenceableError() <<
"expected a single param";
182 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
186 Value handle = cast<Value>(paramOrHandle);
187 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
188 return transformOp.emitSilenceableError() <<
"unexpected value handle";
189 auto payload = state.getPayloadOps(handle);
190 if (!llvm::hasSingleElement(payload))
191 return transformOp.emitSilenceableError()
192 <<
"requires param or handle that is mapped to 1 payload op";
194 Operation *paramOrHandlePayloadOp = *payload.begin();
197 return transformOp.emitSilenceableError()
198 <<
"requires param or handle to be result of op with 1 index "
204 return transformOp.emitSilenceableError()
205 <<
"requires param or handle to be the result of a constant like "
208 reified.push_back(attr.getInt());
217 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
222 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
227 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
232 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
238 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
241 options.rankReductionStrategy =
246 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
251 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
256 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
261 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
266 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
275 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
280 resultTypes.push_back(b.
getType<transform::AnyValueType>());
281 resultTypes.push_back(b.
getType<transform::AnyOpType>());
282 return build(b, result,
288 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
291 int64_t memorySpace) {
293 resultTypes.push_back(b.
getType<transform::AnyValueType>());
294 resultTypes.push_back(b.
getType<transform::AnyOpType>());
295 return build(b, result,
311 void notifyOperationInserted(
Operation *op,
313 ForwardingListener::notifyOperationInserted(op, previous);
315 if (previous.
isSet())
317 auto inserted = newOps.insert(op);
319 assert(inserted.second &&
"expected newly created op");
322 void notifyOperationErased(
Operation *op)
override {
323 ForwardingListener::notifyOperationErased(op);
337 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
338 NewOpsListener newOpsListener(previousListener);
342 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
345 }
else if (getMemcpyOp() ==
"memref.copy") {
348 }
else if (getMemcpyOp() ==
"linalg.copy") {
352 llvm_unreachable(
"invalid memcpy op");
354 if (getAllocOp() ==
"memref.alloc") {
357 }
else if (getAllocOp() ==
"memref.alloca") {
361 llvm_unreachable(
"invalid alloc op");
363 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
364 options.emitDealloc = getEmitDealloc();
368 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
370 for (
Operation *op : state.getPayloadOps(getTarget())) {
375 <<
"failed to bufferize operation";
376 diag.attachNote(op->
getLoc()) <<
"target payload op";
379 allocatedBuffers.push_back(buffer);
383 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
384 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
388 void transform::BufferizeToAllocationOp::getEffects(
390 if (getBufferizeDestinationOnly()) {
402 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
403 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
404 return emitOpError() <<
"unsupported memcpy op";
405 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
406 return emitOpError() <<
"unsupported alloc op";
419 #define DOWNSCALE(trans) \
421 FailureOr<LinalgOp> res = tryApply<trans>(target); \
422 if (succeeded(res)) { \
423 results.push_back(*res); \
424 return DiagnosedSilenceableFailure::success(); \
428 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
429 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
442 #undef DOWNSCALE_NORMAL
443 #undef DOWNSCALE_CALL
445 return emitDefaultSilenceableFailure(target);
459 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
460 if (!decomposableOp) {
462 "payload is not a decomposable op"));
463 return emitDefaultSilenceableFailure(target);
466 FailureOr<SmallVector<Value>> maybeNewResults =
467 decomposableOp.decomposeOperation(rewriter);
468 if (
failed(maybeNewResults))
469 return emitDefaultSilenceableFailure(target);
471 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
472 for (
Value val : *maybeNewResults) {
473 Operation *definition = val.getDefiningOp();
484 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
491 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
495 options.allowReturnAllocsFromLoops =
true;
497 for (
Operation *target : state.getPayloadOps(getTarget())) {
501 <<
"failed to analyze op";
503 rewriter, target, state)))
505 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
516 template <
typename Range>
520 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
526 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
527 if (!tilingInterfaceOp)
528 return transformOp->
emitError(
"only TilingInterface ops are supported");
531 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
532 applyFn(tilingInterfaceOp);
538 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
539 for (
Operation *toReplace : opsToReplace) {
540 for (
OpResult res : toReplace->getResults())
541 if (
auto replacement = tiledResults->replacements.lookup(res))
543 if (toReplace->use_empty()) {
549 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
550 assert(tiledResults->loops.size() == numLoops &&
551 "Mismatched number of loops, tile and fuse transform should have "
553 for (
unsigned int i = 0; i < numLoops; ++i)
554 loopOps[i].push_back(tiledResults->loops[i]);
557 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
558 for (
unsigned int i = 0; i < numLoops; ++i)
559 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
571 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
577 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
581 if (getApplyCleanup()) {
584 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
591 rewriter, getOperation(), state.getPayloadOps(getTarget()),
592 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
593 [&](TilingInterface tilingInterfaceOp)
594 -> FailureOr<scf::SCFTileAndFuseResult> {
595 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
604 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
605 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
606 if (!std::is_permutation(sequence.begin(), sequence.end(),
607 permutation.begin(), permutation.end())) {
608 return emitOpError() <<
"expects interchange to be a permutation, found "
609 << getTileInterchange();
614 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
615 if (numExpectedLoops != getNumResults() - 1)
616 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
625 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
628 Value containingOp) {
631 result.
addTypes({resultType, resultType});
647 (domInfo.
dominates(containingOp, user))) {
648 dominatedUsers.insert(user);
651 if (dominatedUsers.empty())
655 auto forallOp = cast<scf::ForallOp>(containingOp);
661 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
666 newOuts.push_back(outputs[resultNumber]);
669 auto newforallOp = scf::ForallOp::create(
670 rewriter, loc, forallOp.getMixedLowerBound(),
671 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
672 forallOp.getMapping());
674 newforallOp.getRegion().takeBody(forallOp.getRegion());
679 newforallOp.getBody()->addArgument(newOuts.back().getType(),
680 newOuts.back().getLoc());
681 auto bbArgs = newforallOp.getBody()->getArguments();
684 Operation *op = use.getOwner();
685 return newforallOp->isProperAncestor(op);
689 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
691 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
692 Operation *firstYieldOp = yieldingOps.front();
695 Value dst = newforallOp.getRegionIterArgs().back();
697 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->
getLoc(), src,
698 dst, offsets, sizes, strides);
702 newforallOp->getResult(result.index()));
705 newforallOp->getResults().back(),
707 Operation *user = use.getOwner();
708 return dominatedUsers.contains(user);
722 destWorklist.push_back(dst);
724 while (!destWorklist.empty()) {
725 Value currentDst = destWorklist.pop_back_val();
729 if (src == currentDst)
734 auto bbArg = dyn_cast<BlockArgument>(currentDst);
738 Block *parentBlock = bbArg.getOwner();
739 assert(parentBlock &&
"unlinked block argument");
742 assert(parentOp &&
"expected block argument with parent operation");
745 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
749 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
751 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
752 Value loopBlockArgument =
754 destWorklist.push_back(loopBlockArgument);
767 static std::tuple<SmallVector<Operation *>,
Operation *>
770 LDBG() <<
"Try to fuse a direct extract use";
771 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
772 if (!tileableProducer) {
774 <<
"producer is not a TileableInterface: " << *producerOp;
781 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
782 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
783 return sliceOp && containingOp->isProperAncestor(sliceOp);
787 if (it == tileableProducer->getUsers().end()) {
788 diag.attachNote(tileableProducer->getLoc())
789 <<
"could not find fusion opportunity for: " << *tileableProducer;
792 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
805 if (LoopLikeOpInterface containerLoop =
806 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
813 cast<DestinationStyleOpInterface>(
clone).getDpsInitsMutable()) {
814 Value producerOperand =
817 containerLoop.getRegionIterArgs()) {
818 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
819 Value consumerOperand =
820 containerLoop->getOperand(bbArg->getOperandNumber());
822 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
823 initOperandPtr.set(containerIterArg);
829 tileableProducer = dyn_cast<TilingInterface>(
clone);
833 int64_t resultNumber =
834 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
835 LDBG() <<
"resultNumber: " << resultNumber;
840 FailureOr<TilingResult> tileAndFuseResult =
841 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
844 if (
failed(tileAndFuseResult)) {
845 diag.attachNote(tileableProducer->getLoc())
846 <<
"failed to tile producer op: " << *tileableProducer;
851 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
852 LDBG() <<
"tiledProducer: " << *tiledOp;
857 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
858 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
859 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
860 if (
failed(maybeRankReduced)) {
862 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
863 << tileAndFuseResult->tiledValues[0]
864 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
867 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
871 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
872 resultNumber, offsets, sizes);
875 if (dyn_cast<LoopLikeOpInterface>(containingOp))
876 rewriter.
eraseOp(tileableProducer);
878 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
891 LDBG() <<
"Try to fuse an extract use through block argument";
893 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
894 if (!tileableProducer) {
896 <<
"producer is not a TileableInterface: " << *producerOp;
901 scf::ForallOp forallOp;
902 auto itProducerUses =
903 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
904 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
908 if (!forallOp || forallOp != containingOp) {
909 diag.attachNote(tileableProducer->getLoc())
910 <<
"could not find a use by the containing op: " << *tileableProducer;
925 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
926 return sliceOp && containingOp->isProperAncestor(sliceOp);
930 if (itBBArgUsers == bbArg.
getUsers().end()) {
932 <<
"could not find fusion opportunity for bbArg: " << bbArg;
935 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
943 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
944 LDBG() <<
"resultNumber: " << resultNumber;
949 rewriter, tileableProducer->getLoc(), tileableProducer,
950 destinationTensors))) {
951 diag.attachNote(tileableProducer->getLoc())
952 <<
"failed to get destination tensors for: " << *tileableProducer;
957 bvm.
map(destinationTensors[resultNumber], bbArg);
958 auto tileableProducerClone =
959 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
961 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
964 FailureOr<TilingResult> tileAndFuseResult =
965 tileableProducerClone.generateResultTileValue(
966 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
967 sliceOpToTile.getMixedSizes());
968 if (
failed(tileAndFuseResult)) {
969 diag.attachNote(tileableProducer->getLoc())
970 <<
"failed to tile producer op: " << *tileableProducer;
975 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
976 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
977 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
978 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
979 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
984 destinationTensors.front());
987 return tileAndFuseResult->tiledOps;
993 LDBG() <<
"Try to fuse an use by cloning";
998 for (
OpOperand &use : result.getUses()) {
1000 uses.push_back(&use);
1005 if (containingOp == use.getOwner()) {
1007 <<
"producer op use by containing op cannot be fused by cloning";
1015 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1024 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1025 "Parallel insert slice is not a valid clone destination");
1026 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1027 LDBG() <<
"resultNumber: " << resultNumber;
1031 fusedOp = rewriter.
clone(*producerOp);
1033 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1038 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1048 auto producerOps = state.getPayloadOps(getProducerOp());
1049 auto containingOps = state.getPayloadOps(getContainingOp());
1050 if (!llvm::hasSingleElement(containingOps)) {
1052 <<
"requires exactly one containing_op handle (got "
1053 << llvm::range_size(containingOps) <<
")";
1055 Operation *containingOp = *containingOps.begin();
1058 if (std::empty(producerOps)) {
1060 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1067 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1068 for (
const auto &it :
enumerate(remainingProducers)) {
1071 int64_t numUsesInContainingOp =
1073 return containingOp->isAncestor(op);
1078 if (numUsesInContainingOp > 0) {
1079 if (numUsesInContainingOp == 1)
1080 remainingProducers.erase(remainingProducers.begin() + it.index());
1087 while (!remainingProducers.empty()) {
1088 auto nextProducer = getNextProducer();
1089 if (
failed(nextProducer)) {
1091 <<
"could not find next producer to fuse into container";
1092 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1100 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1107 auto [tiledOps, newContainingOp] =
1109 if (!tiledOps.empty()) {
1110 LDBG() <<
"\nFused a direct extract use\n" << *containingOp;
1111 fusedOps.append(tiledOps);
1112 if (newContainingOp) {
1120 LogicalResult replacementStatus =
1123 (void)replacementStatus;
1124 assert(succeeded(replacementStatus) &&
1125 "unable to update transform state mapping");
1126 rewriter.
eraseOp(containingOp);
1127 containingOp = newContainingOp;
1134 rewriter,
diag, producerOp, containingOp);
1135 if (!tiledContainingOpOperand.empty()) {
1136 LDBG() <<
"\nFused an extract use through block argument\n"
1138 fusedOps.append(tiledContainingOpOperand);
1145 LDBG() <<
"\nFused an use by cloning\n" << *containingOp;
1146 fusedOps.push_back(cloned);
1152 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1153 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1157 void transform::FuseIntoContainingOp::getEffects(
1175 if (isa<GenericOp>(target)) {
1181 if (succeeded(
generic)) {
1182 results.
push_back(generic->getOperation());
1185 return emitDefaultSilenceableFailure(target);
1198 if (!isa<GenericOp>(target)) {
1203 FailureOr<LinalgOp> named =
1205 if (succeeded(named)) {
1206 results.
push_back(named->getOperation());
1209 return emitDefaultSilenceableFailure(target);
1223 if (interchangeVector.empty()) {
1228 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1229 if (interchangeVector.size() != numLoops) {
1230 return emitSilenceableError()
1231 << getIteratorInterchangeAttrName() <<
" has length ("
1232 << interchangeVector.size()
1233 <<
") different from the number of loops in the target operation ("
1246 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1247 if (!std::is_permutation(sequence.begin(), sequence.end(),
1248 permutation.begin(), permutation.end())) {
1249 return emitOpError()
1250 <<
"expects iterator_interchange to be a permutation, found "
1251 << getIteratorInterchange();
1266 if (!isa<linalg::CopyOp>(targetOp)) {
1268 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1269 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1273 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1274 if (!copyOp.hasPureBufferSemantics()) {
1276 emitSilenceableError()
1277 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1278 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1284 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1285 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1286 Value input = inputs.front();
1287 Value output = outputs.front();
1292 if (!isa<ShapedType>(input.
getType())) {
1294 emitSilenceableError()
1295 <<
"cannot transform a linalg.copy which input has no shape";
1296 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1301 assert(isa<ShapedType>(output.
getType()));
1303 if (cast<ShapedType>(input.
getType()).getElementType() !=
1304 cast<ShapedType>(output.
getType()).getElementType()) {
1306 emitSilenceableError()
1307 <<
"cannot transform a linalg.copy with different source and "
1308 "destination element types ";
1309 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1330 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1331 FailureOr<LowerPackResult> res =
1332 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1335 <<
"cannot lower to pad + expand + transpose";
1338 transformResults.
push_back(res->expandShapeOp);
1339 transformResults.
push_back(res->transposeOp);
1352 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1353 FailureOr<LowerUnPackOpResult> res =
1354 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1357 emitSilenceableError()
1358 <<
"cannot lower to transpose + collapse + extract";
1359 diag.attachNote(target->getLoc()) <<
"target payload op";
1362 transformResults.
push_back(res->emptyOp);
1363 transformResults.
push_back(res->transposeOp);
1364 transformResults.
push_back(res->collapseShapeOp);
1365 transformResults.
push_back(res->extractSliceOp);
1395 if (getOps().has_value())
1396 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1398 auto payloadOps = state.getPayloadOps(getTarget());
1399 if (!llvm::hasSingleElement(payloadOps)) {
1404 bool incorrectNumOperandTypes =
false;
1411 if (getInterface().has_value()) {
1412 auto iface = getInterface().value();
1413 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1416 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1417 !isa<TilingInterface>(op))
1419 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1420 !isa<LoopLikeOpInterface>(op))
1425 if (getOpAttrs().has_value()) {
1426 DictionaryAttr opAttrs = getOpAttrs().value();
1428 if (attr.getName() == getInterfaceAttrName() ||
1429 attr.getName() == getOpsAttrName())
1431 if (!op->
hasAttr(attr.getName()))
1433 if (op->
getAttr(attr.getName()) != attr.getValue())
1438 if (getFilterResultType().has_value()) {
1439 Type t = getFilterResultType().value();
1444 if (getFilterOperandTypes().has_value()) {
1445 mlir::ArrayAttr types = getFilterOperandTypes().value();
1448 if (types.size() == 1) {
1451 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1452 Type t = cast<::mlir::Type>(typeattr.getValue());
1454 [&](
Type operandType) { return operandType == t; }))
1459 if (types.size() != operandTypes.size()) {
1460 incorrectNumOperandTypes =
true;
1464 for (
auto [attr, operandType] :
1465 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1466 auto typeattr = cast<mlir::TypeAttr>(attr);
1467 Type type = cast<::mlir::Type>(typeattr.getValue());
1469 if (type != operandType)
1480 (*payloadOps.begin())->
walk(matchFun);
1481 if (incorrectNumOperandTypes)
1483 "type, then it must contain as much types as "
1484 "the number of operands in the target ops");
1485 results.
set(cast<OpResult>(getResult()), res);
1500 Type &targetType,
Type &lowSizeType,
1502 Type &splitPointType) {
1503 FunctionType funcType;
1508 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1509 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1510 "argument and one result";
1512 targetType = funcType.getInput(0);
1513 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1521 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1522 if (target.hasDynamicShape()) {
1523 auto diag = emitSilenceableError()
1524 <<
"cannot compute parametric tile sizes for dynamically "
1525 "shaped payload op";
1526 diag.attachNote(target->getLoc()) <<
"payload op";
1531 target, getDimension(), getTargetSize(), getDivisor());
1533 return emitSilenceableError()
1534 <<
"failed to compute multi-size tiling sizes";
1537 Builder builder(target.getContext());
1538 results.
assign(llvm::map_range(
1540 spec->lowTileSize * spec->lowTripCount}),
1541 [&builder,
this](int64_t value) {
1553 builder, target, getDimension(), targetSize, divisor);
1555 return emitSilenceableError() <<
"could not generate tile size computation";
1562 {spec->lowTileSize, spec->lowTripCount});
1563 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1564 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1565 assert(lowTileSize && highTileSize && splitPoint &&
1566 "tile sizes are not produced by operations");
1574 void transform::MultiTileSizesOp::getEffects(
1578 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1587 return emitOpError() <<
"expects all results type to be the same";
1607 builder.
getContext(), GenericOp::getOperationName());
1608 build(builder, result,
1617 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1624 auto targetOps = state.getPayloadOps(getTarget());
1626 if (std::empty(targetOps)) {
1627 transformResults.
set(cast<OpResult>(getPackedOp()),
1632 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1633 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1634 return emitSilenceableError()
1635 <<
"requires target to map to exactly 1 LinalgOp (got "
1636 << llvm::range_size(targetOps) <<
")";
1639 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1640 return emitSilenceableError()
1641 <<
"requires number of packed sizes match the number of loops ("
1642 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1649 state, *
this, packedSizes, getMixedPackedSizes());
1652 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1656 transformResults.
set(cast<OpResult>(getPackedOp()),
1657 {maybeResult->packedLinalgOp.getOperation()});
1661 void transform::PackOp::getEffects(
1675 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1676 <<
" is not a valid permutation";
1679 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1680 for (
auto [s, nmo] :
1681 llvm::zip_equal(getMixedMatmulPackedSizes(),
1682 getMatmulPaddedSizesNextMultipleOf())) {
1685 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1686 return emitOpError() <<
"at most one of the packed_size and the "
1687 "padded_sizes_next_multiple_of can be nonzero "
1688 "for the matmul strategy";
1700 for (
Operation *op : state.getPayloadOps(getTarget())) {
1701 auto linalgOp = dyn_cast<LinalgOp>(op);
1712 getMixedMatmulPackedSizes(),
1714 getMatmulPaddedSizesNextMultipleOf(),
1715 getMatmulInnerDimsOrder());
1716 if (succeeded(packResult)) {
1717 results.push_back(packResult->packedLinalgOp);
1720 results.push_back(linalgOp);
1722 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1728 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1732 void transform::PackGreedilyOp::getEffects(
1746 return emitOpError() << getInnerPermAttrName()
1747 <<
" is not a valid permutation";
1750 return emitOpError() << getOuterPermAttrName()
1751 <<
" is not a valid permutation";
1753 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1754 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1755 <<
" or " << getOuterPermAttrName()
1756 <<
" must be specified";
1762 // Discriminates which permutation of a pack/unpack-style op a check or
// transform applies to: the outer-dims permutation (outer_perm) or the
// inner-dims permutation (inner_perm). Used when validating user-provided
// permutations against the corresponding rank of the target op.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1772 template <
typename RelayoutOpTy>
1775 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1777 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1778 "applies to only pack or unpack operations");
1779 if (!op || permutation.empty())
1781 size_t innerRank = op.getInnerDimsPos().size();
1782 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1786 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1787 return permutation.size() == op.getSourceRank() &&
1790 return permutation.size() == op.getDestRank() &&
1798 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1799 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1801 if (std::empty(packOrUnpackOps)) {
1802 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1803 transformResults.
set(cast<OpResult>(getPackOp()), {});
1804 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1810 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1811 !llvm::hasSingleElement(linalgOps)) {
1812 return emitSilenceableError()
1813 <<
"requires target to map to exactly 1 "
1814 "packing op and 1 packed op ("
1815 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1816 << llvm::range_size(linalgOps) <<
")";
1820 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1821 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1822 if ((!packOp && !unPackOp)) {
1823 return emitSilenceableError() <<
"requires target to map to a "
1824 "linalg.pack or linalg.unpack";
1826 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1827 if (!linalgOpTarget)
1828 return emitSilenceableError() <<
"requires a LinalgOp target";
1832 if (packOp && packOp.getResult().hasOneUse())
1833 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1835 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1836 if (linalgOp != linalgOpTarget) {
1838 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1839 : StringLiteral{
"not produced by the LinalgOp target"};
1840 return emitSilenceableError() << errorMsg;
1846 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1847 OpOperand *packUse = linalgOp.getDpsInitOperand(
1848 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1850 if (!packOp || !packOp.getResult().hasOneUse())
1851 return emitSilenceableError() <<
"could not find matching pack op";
1855 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1857 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1858 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1859 ? StringLiteral{
"invalid outer_perm"}
1860 : StringLiteral{
"invalid inner_perm"};
1864 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1865 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1871 assert(packOp && linalgOp &&
"unexpected null op");
1875 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1877 assert(succeeded(res) &&
"unexpected packTranspose failure");
1880 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1881 transformResults.
set(cast<OpResult>(getPackedOp()),
1882 {res->transposedLinalgOp});
1884 transformResults.
set(cast<OpResult>(getUnPackOp()),
1885 {res->transposedUnPackOp});
1887 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1902 StringRef copyBackOp,
1903 bool usePrescribedTensorShapes) {
1913 (padToMultipleOf.empty()
1915 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1916 b.getI64ArrayAttr(nofoldFlags),
1917 b.getArrayAttr(transposePaddings),
1918 b.getStringAttr(copyBackOp),
1920 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
1928 StringRef copyBackOp,
1929 bool usePrescribedTensorShapes) {
1934 staticPadToMultipleOf);
1941 dynamicPadToMultipleOf,
1942 staticPadToMultipleOf,
1946 usePrescribedTensorShapes);
1949 void PadOp::getEffects(
1959 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1966 auto transformOp = cast<TransformOpInterface>(getOperation());
1969 for (
Operation *target : state.getPayloadOps(getTarget())) {
1970 auto linalgTarget = dyn_cast<LinalgOp>(target);
1971 if (!linalgTarget) {
1972 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1973 diag.attachNote(target->
getLoc()) <<
"target op";
1979 for (int64_t packPadding :
1980 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1981 nofoldFlags.push_back(
static_cast<bool>(packPadding));
1985 for (
auto const &[untypedAttr, elementOrTensorType] :
1986 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1988 if (isa<ub::PoisonAttr>(untypedAttr)) {
1989 paddingValues.push_back(untypedAttr);
1992 auto attr = dyn_cast<TypedAttr>(untypedAttr);
1994 emitOpError(
"expects padding values to be typed attributes or poison");
1999 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2003 if (!parsedAttr || parsedAttr.getType() != elementType) {
2004 auto diag = this->emitOpError(
"expects a padding that parses to ")
2005 << elementType <<
", got " << untypedAttr;
2006 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2009 paddingValues.push_back(parsedAttr);
2013 if (attr.getType() != elementType) {
2014 auto diag = this->emitOpError(
"expects a padding value of type ")
2015 << elementType <<
", got " << attr;
2016 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2019 paddingValues.push_back(attr);
2024 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2025 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2026 cast<ArrayAttr>(transposeVector)));
2031 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2035 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2038 if (padToMultipleOf.empty())
2042 options.padToMultipleOf = padToMultipleOf;
2043 options.paddingValues = paddingValues;
2044 options.nofoldFlags = nofoldFlags;
2045 if (getCopyBackOp() ==
2046 bufferization::MaterializeInDestinationOp::getOperationName()) {
2049 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2051 }
else if (getCopyBackOp() == kCopyOpNone) {
2054 llvm_unreachable(
"unsupported copy_back op");
2057 bool irChanged =
false;
2058 if (getUsePrescribedTensorShapes() &&
2059 linalgTarget.hasPureTensorSemantics()) {
2062 for (
OpOperand &operand : linalgTarget->getOpOperands()) {
2063 for (
auto [i, dim] :
llvm::enumerate(linalgTarget.getShape(&operand))) {
2064 if (ShapedType::isStatic(dim))
2066 options.setSizeToPadTo(operand.getOperandNumber(), i,
2068 operand.get().getLoc(),
2078 replacements, newPadOps))) {
2081 diag.attachNote(target->
getLoc()) <<
"target op";
2084 auto diag = emitSilenceableError() <<
"failed to pad op";
2085 diag.attachNote(target->
getLoc()) <<
"target op";
2094 rewriter.
replaceOp(linalgTarget, replacements);
2095 paddedOps.push_back(paddedOp);
2096 padOps.append(newPadOps.begin(), newPadOps.end());
2098 for (
Value v : replacements) {
2099 Operation *copyBackOp = v.getDefiningOp();
2100 if (!llvm::is_contained(copyBackOps, copyBackOp))
2101 copyBackOps.push_back(copyBackOp);
2106 results.
set(cast<OpResult>(getPadded()), paddedOps);
2107 results.
set(cast<OpResult>(getPad()), padOps);
2108 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2114 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2115 if (any_of(nofoldFlags, [](int64_t packPadding) {
2116 return packPadding != 0 && packPadding != 1;
2118 return emitOpError()
2119 <<
"expects nofold_flags to contain booleans (0/1), found "
2120 << getNofoldFlags();
2124 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2125 if (any_of(paddingDimensions,
2126 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2127 return emitOpError() <<
"expects padding_dimensions to contain positive "
2129 << getPaddingDimensions();
2131 if (!getMixedPadToMultipleOf().empty()) {
2132 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2133 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2136 ArrayAttr transposes = getTransposePaddings();
2139 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2140 if (!std::is_permutation(sequence.begin(), sequence.end(),
2141 transpose.begin(), transpose.end())) {
2142 return emitOpError()
2143 <<
"expects transpose_paddings to be a permutation, found "
2147 if (getCopyBackOp() !=
2148 bufferization::MaterializeInDestinationOp::getOperationName() &&
2149 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2150 getCopyBackOp() != kCopyOpNone)
2151 return emitOpError() <<
"invalid copy_back_op";
2159 void transform::PadTilingInterfaceOp::build(
OpBuilder &b,
2163 bool padToMultipleOf) {
2173 : b.getDenseI64ArrayAttr(paddingSizes)),
2175 padToMultipleOf ? b.getUnitAttr() : nullptr);
2178 void transform::PadTilingInterfaceOp::build(
2185 staticPaddingSizes);
2191 dynamicPaddingSizes,
2196 void transform::PadTilingInterfaceOp::getEffects(
2205 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2207 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2216 for (
Operation *target : state.getPayloadOps(getTarget())) {
2217 auto targetOp = dyn_cast<TilingInterface>(target);
2219 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2220 diag.attachNote(target->
getLoc()) <<
"target op";
2227 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2228 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2230 diag.attachNote(target->
getLoc()) <<
"target op";
2236 for (
auto const &[untypedAttr, elementOrTensorType] :
2237 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2238 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2241 if (isa<ub::PoisonAttr>(untypedAttr)) {
2242 paddingValues.push_back(untypedAttr);
2246 emitOpError(
"expects padding values to be typed attributes or poison");
2250 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2254 if (!parsedAttr || parsedAttr.getType() != elementType) {
2255 auto diag = this->emitOpError(
"expects a padding that parses to ")
2256 << elementType <<
", got " << attr;
2257 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2260 paddingValues.push_back(parsedAttr);
2264 if (attr.getType() != elementType) {
2265 auto diag = this->emitOpError(
"expects a padding value of type ")
2266 << elementType <<
", got " << attr;
2267 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2270 paddingValues.push_back(attr);
2274 TilingInterface paddedOp;
2276 options.setPaddingValues(paddingValues)
2277 .setPaddingSizes(getMixedPaddingSizes())
2278 .setPadToMultipleOf(getPadToMultipleOf());
2283 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options,
2285 if (
failed(maybePaddedOp)) {
2286 auto diag = emitSilenceableError() <<
"failed to pad op";
2287 diag.attachNote(target->
getLoc()) <<
"target op";
2292 paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2293 padOps.append(newPadOps.begin(), newPadOps.end());
2296 results.
set(cast<OpResult>(getPadded()), paddedOps);
2297 results.
set(cast<OpResult>(getPad()), padOps);
2311 auto targetOps = state.getPayloadOps(getTarget());
2312 auto loopOps = state.getPayloadOps(getLoop());
2313 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2315 <<
"requires exactly one target and one loop handle (got "
2316 << llvm::range_size(targetOps) <<
" and "
2317 << llvm::range_size(loopOps) <<
")";
2320 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2321 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2322 if (!padOp || !loopOp)
2325 FailureOr<linalg::detail::PackingResult> result =
2331 if (result->clonedLoopIvs.empty()) {
2332 transformResults.
set(cast<OpResult>(getPackingLoop()),
2333 {result->hoistedPadOp.getOperation()});
2336 auto outerPackedLoop =
2338 transformResults.
set(cast<OpResult>(getPackingLoop()),
2339 {outerPackedLoop.getOperation()});
2345 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2346 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2348 return emitOpError() <<
"expects transpose to be a permutation, found "
2354 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2364 tensor::PadOp target,
2367 tensor::PadOp hoistedPadOp;
2369 FailureOr<Value> result =
2371 hoistedPadOp, transposeOps);
2372 if (succeeded(result)) {
2382 return emitDefaultSilenceableFailure(target);
2387 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2388 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2390 return emitOpError() <<
"expects transpose to be a permutation, found "
2406 if (!getOperandsToPromote().empty())
2408 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2409 if (getUseFullTilesByDefault())
2411 getUseFullTilesByDefault());
2412 if (getUseOriginalSubviewSize())
2416 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2417 if (!getUseFullTileBuffers().empty())
2419 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2420 if (getAlignment().has_value())
2421 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2422 if (getMemorySpace().has_value())
2423 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2425 if (getMapping().has_value()) {
2427 auto mapping = *getMapping();
2428 if (mapping.size() > 1)
2429 return emitDefaultDefiniteFailure(target);
2431 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2433 if (addressSpace.getAddressSpace() ==
2434 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2441 }
else if (addressSpace.getAddressSpace() ==
2442 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2450 return emitDefaultDefiniteFailure(target);
2455 return emitDefaultDefiniteFailure(target);
2458 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2460 return emitDefaultDefiniteFailure(target);
2473 auto payload = state.getPayloadOps(getTarget());
2477 if (target->getNumOperands() > 0)
2480 target->getNumRegions() > 0)
2482 <<
"expected target that is isolated from above";
2486 Operation *pattern = &getBodyRegion().front().front();
2489 if (getOperation()->isAncestor(target))
2494 replacements.push_back(replacement);
2496 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2500 void transform::ReplaceOp::getEffects(
2508 if (!getBodyRegion().hasOneBlock())
2509 return emitOpError() <<
"expected one block";
2510 if (std::distance(getBodyRegion().front().begin(),
2511 getBodyRegion().front().end()) != 1)
2512 return emitOpError() <<
"expected one operation in block";
2513 Operation *replacement = &getBodyRegion().front().front();
2516 <<
"expected replacement without operands";
2520 <<
"expect op that is isolated from above";
2538 target.createFlatListOfOperandDims(b, loc);
2539 AffineMap map = target.getShapesToLoopsMap();
2554 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2555 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2556 if (
failed(maybeTilingResult))
2557 return emitDefaultDefiniteFailure(target);
2559 if (target->getNumResults())
2560 rewriter.
replaceOp(target, maybeTilingResult->replacements);
2564 results.
reserve(maybeTilingResult->tiledOps.size());
2565 for (
Operation *tiled : maybeTilingResult->tiledOps)
2579 for (
Operation *target : state.getPayloadOps(getTarget())) {
2580 auto tilingOp = dyn_cast<TilingInterface>(*target);
2583 emitSilenceableError()
2584 <<
"expected the payload to implement TilingInterface";
2585 diag.attachNote(target->getLoc()) <<
"payload op";
2589 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2591 if (
failed(generatedLoops))
2592 return emitDefaultDefiniteFailure(target);
2593 for (scf::ForOp &loop : *generatedLoops) {
2594 loops.push_back(loop.getOperation());
2598 results.
set(cast<OpResult>(getResult()), loops);
2607 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2612 FailureOr<Operation *> maybeResult =
2614 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2615 [&rewriter](
auto op) {
2619 return emitDefaultSilenceableFailure(target);
2633 llvm::to_vector(state.getPayloadOps(getTarget()));
2635 bool isMultiwaySplit = getMultiway();
2637 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2639 <<
"requires exactly one target when "
2640 "multiway split is enabled (got "
2641 << llvm::range_size(payload) <<
")";
2646 if (!isMultiwaySplit)
2647 chunkSizes.reserve(payload.size());
2649 if (getDynamicChunkSizes()) {
2651 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2652 chunkSizes = llvm::to_vector(llvm::map_range(
2653 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2656 diag = emitSilenceableError()
2657 <<
"expected dynamic split point handle to point to a "
2658 "single-result index-typed op";
2659 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2664 chunkSizes = llvm::to_vector(
2665 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2668 if (
diag.isSilenceableFailure())
2673 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2675 <<
"expected the dynamic split point handle to point to as "
2677 << chunkSizes.size() <<
") as the target handle ("
2678 << payload.size() <<
")";
2681 chunkSizes.resize(payload.size(),
2685 auto checkStructuredOpAndDimensions =
2688 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2689 diag.attachNote(loc) <<
"target op";
2693 if (getDimension() >= linalgOp.getNumLoops()) {
2694 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2695 <<
" does not exist in target op";
2696 diag.attachNote(loc) <<
"target op";
2702 auto checkFailureInSplitting =
2706 diag.attachNote(loc) <<
"target op";
2713 if (isMultiwaySplit) {
2716 TilingInterface head, tail;
2719 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2723 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2724 if (
diag.isSilenceableFailure())
2730 target = tail.getOperation();
2735 linalgOp = cast<LinalgOp>(target);
2740 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2741 getDimension(), chunkSize);
2745 checkFailureInSplitting(!head && !tail, loc);
2746 if (
diag.isDefiniteFailure())
2749 opList.push_back(head.getOperation());
2754 opList.push_back(tail.getOperation());
2760 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2763 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2765 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2767 if (
diag.isSilenceableFailure())
2771 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2772 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2773 getDimension(), std::get<1>(pair));
2777 checkFailureInSplitting(!first.back() && !second.back(), loc);
2782 if (!second.back()) {
2783 noSecondPart = target;
2788 if (second.size() != first.size() && !second.empty()) {
2789 auto diag = emitSilenceableError()
2790 <<
"splitting does not produce the second part for a subset "
2793 <<
"expected splitting to produce the second part of all "
2794 "or none of the targets";
2796 <<
"first target with no second part";
2800 opList.append(first);
2802 opList.append(second);
2804 results.
set(cast<OpResult>(getSplitList()), opList);
2808 void SplitOp::getEffects(
2811 if (getDynamicChunkSizes())
2819 IntegerAttr staticChunkSizes;
2825 if (!dynamicPointParseResult.
has_value()) {
2826 int64_t staticChunkSizesValue;
2840 if (dynamicPointParseResult.
has_value()) {
2841 Type ChunkSizesType;
2854 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2861 printer <<
" " << getTarget() <<
" after ";
2862 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2863 if (staticChunkSize != ShapedType::kDynamic)
2864 printer << staticChunkSize;
2866 printer << getDynamicChunkSizes();
2869 {getStaticChunkSizesAttrName()});
2870 printer <<
" : " << getTarget().getType();
2871 if (staticChunkSize == ShapedType::kDynamic)
2872 printer <<
", " << getDynamicChunkSizes().getType();
2876 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2877 (getDynamicChunkSizes() ==
nullptr)) {
2878 return emitOpError() <<
"expects either a dynamic or a static split "
2879 "point to be provided";
2888 void transform::SplitReductionOp::build(
2890 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2891 bool useScalingAlgorithm,
bool useAlloc) {
2897 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2899 if (innerParallel) {
2900 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2903 if (useScalingAlgorithm) {
2905 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2913 result.
addTypes({resultType, resultType, resultType, resultType});
2922 unsigned(getInsertSplitDimension()),
2923 bool(getInnerParallel())};
2926 FailureOr<SplitReductionResult> splitResult =
2927 (getUseScalingAlgorithm())
2931 return emitDefaultDefiniteFailure(target);
2933 results.
push_back(splitResult->initOrAlloc);
2935 results.
push_back(splitResult->splitLinalgOp);
2936 results.
push_back(splitResult->resultCombiningLinalgOp);
2944 void transform::TileReductionUsingForOp::build(
2955 build(builder, result,
2959 staticTileSizesAttr);
2968 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2969 if (!partialReductionOp) {
2972 "Operation should implement PartialReductionOpInterface");
2976 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
2977 if (reductionDims.empty()) {
2978 for (
auto [idx, iteratorType] :
2980 if (iteratorType == utils::IteratorType::reduction)
2981 reductionDims.push_back(idx);
2987 options.setReductionTilingStrategy(
2990 options.setReductionDims(reductionDims);
2991 FailureOr<scf::SCFTilingResult> result =
2996 "failed to tile using partial reduction");
2998 rewriter.
replaceOp(target, result->replacements);
2999 for (
Value initValue : result->initialValues)
3001 for (
auto parallelTiledOp : result->tiledOps)
3003 for (
auto mergeOp : result->mergeOps)
3005 results.
push_back(result->loops.front());
3013 void transform::TileReductionUsingForallOp::build(
3016 ArrayAttr mapping) {
3026 build(builder, result,
3030 staticNumThreadsAttr,
3031 staticTileSizesAttr,
3047 options.setReductionTilingStrategy(
3049 if (!getNumThreads().empty()) {
3050 options.setNumThreads(numThreads);
3052 options.setTileSizes(tileSizes);
3054 if (
auto mapping = getMapping()) {
3055 options.setMapping(mapping.value().getValue());
3058 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3059 if (reductionDims.empty()) {
3060 for (
auto [idx, iteratorType] :
3062 if (iteratorType == utils::IteratorType::reduction)
3063 reductionDims.push_back(idx);
3066 options.setReductionDims(reductionDims);
3068 rewriter, cast<TilingInterface>(target.getOperation()),
options);
3071 auto diag = emitSilenceableError() <<
"could not tile reduction";
3074 rewriter.
replaceOp(target, result->replacements);
3076 for (
Value initValue : result->initialValues)
3078 for (
auto parallelTiledOp : result->tiledOps)
3080 for (
auto mergeOp : result->mergeOps)
3082 results.
push_back(result->loops.front());
3096 llvm::to_vector(state.getPayloadOps(getTarget()));
3098 if (!llvm::hasSingleElement(targetOps)) {
3100 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3105 auto linalgOp = dyn_cast<LinalgOp>(target);
3106 auto tileableOp = dyn_cast<TilingInterface>(target);
3111 OpBuilder builder(linalgOp.getContext());
3113 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3114 if (linalgOp.hasDynamicShape()) {
3115 auto diag = emitSilenceableError()
3116 <<
"cannot compute parametric tile sizes for dynamically "
3117 "shaped payload op";
3118 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3122 FailureOr<StaticContinuousTileSizeSpecification> spec =
3126 return emitSilenceableError()
3127 <<
"failed to compute multi-size tiling sizes";
3132 for (
auto &&[tileSize, tripCount] :
3133 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3134 chunkSizes.push_back(tileSize * tripCount);
3137 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
3142 getI64AttrsFromI64(spec->tileSizes));
3143 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3144 getI64AttrsFromI64(chunkSizes));
3152 unsigned dimension = getDimension();
3155 builder, tileableOp, dimension, targetSize,
true);
3157 return emitSilenceableError() <<
"could not generate tile size computation";
3169 for (
auto &&[tileSize, tripCount] :
3170 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3171 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3172 chunkSizes.push_back(splitPoint);
3176 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
3182 getDefiningOps(spec->tileSizes));
3183 transformResults.
set(cast<OpResult>(getChunkSizes()),
3184 getDefiningOps(chunkSizes));
3192 return emitOpError() <<
"expects all results type to be the same";
3198 void transform::ContinuousTileSizesOp::getEffects(
3216 Type &tileSizesType,
3217 Type &chunkSizesType) {
3218 FunctionType funcType;
3223 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3224 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3225 "argument and one result";
3227 targetType = funcType.getInput(0);
3228 tileSizesType = chunkSizesType = funcType.getResult(0);
3237 void transform::TileUsingForOp::build(
3242 return build(builder, result, loopTypes,
3246 interchange, scalableSizes);
3249 void transform::TileUsingForOp::build(
3253 build(builder, result, target,
3255 interchange, scalableSizes);
3258 void transform::TileUsingForOp::build(
3265 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3269 void transform::TileUsingForOp::build(
3281 unsigned numExpectedLoops =
3282 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3284 resultTypes.reserve(numExpectedLoops);
3285 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3286 "expected one loop type or as many as loops");
3287 if (loopTypes.size() == 1)
3288 resultTypes.append(numExpectedLoops, loopTypes[0]);
3290 llvm::append_range(resultTypes, loopTypes);
3292 if (scalableSizes.has_value())
3293 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3294 build(builder, result, target.
getType(),
3298 staticTileSizesAttr,
3300 expandedScalableSizes);
3305 return emitOpError(
"expected same number of sizes (")
3307 << getScalableSizes().size() <<
")";
3309 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3310 if (getLoops().size() != numExpectedLoops)
3311 return emitOpError(
"expected number of loops to tile (")
3312 << numExpectedLoops <<
") to match number of `loops` results ("
3313 << getLoops().size() <<
")";
3324 llvm::to_vector(state.getPayloadOps(getTarget()));
3330 if (isa<ParamType>(transformValue.getType())) {
3331 dynamicSizeProducers.push_back({});
3333 paramSizes.push_back(
3334 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
3335 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3338 if (paramSizes.back().size() != targets.size()) {
3340 emitSilenceableError()
3341 <<
"expected as many parameter values ("
3342 << dynamicSizeProducers.back().size() <<
") as target ops ("
3343 << targets.size() <<
")";
3344 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3350 paramSizes.push_back({});
3351 dynamicSizeProducers.push_back(
3352 llvm::to_vector(state.getPayloadOps(transformValue)));
3354 if (dynamicSizeProducers.back().size() != targets.size()) {
3356 emitSilenceableError()
3357 <<
"expected as many dynamic size-producing operations ("
3358 << dynamicSizeProducers.back().size() <<
") as target ops ("
3359 << targets.size() <<
")";
3360 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3364 for (
Operation *op : dynamicSizeProducers.back()) {
3371 emitSilenceableError() <<
"expected sizes to be produced by ops "
3372 "with a single index-type result";
3373 diag.attachNote(op->
getLoc()) <<
"size producer op";
3374 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3381 loops.resize(getLoops().size());
3382 auto scalableSizes = getScalableSizes();
3384 auto tilingInterface = dyn_cast<TilingInterface>(op);
3385 if (!tilingInterface) {
3387 emitSilenceableError()
3388 <<
"only ops implementing TilingInterface are supported";
3389 diag.attachNote(op->
getLoc()) <<
"target op";
3392 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3394 emitSilenceableError()
3395 <<
"too many tiles provided, expected at most "
3396 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3397 << tileSizes.size();
3398 diag.attachNote(op->
getLoc()) <<
"target op";
3403 if (tileSizes.empty()) {
3412 sizes.reserve(tileSizes.size());
3413 unsigned dynamicIdx = 0;
3416 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3417 if (scalableSizes[ofrIdx]) {
3419 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3421 vector::VectorScaleOp::create(b, getLoc(), b.
getIndexType());
3423 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3425 sizes.push_back(attr);
3432 assert((dynamicSizes.empty() ^ params.empty()) &&
3433 "expected either dynamic sizes or parameters");
3434 if (!params.empty()) {
3437 sizes.push_back(dynamicSizes[index]->getResult(0));
3445 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3446 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3447 if (
failed(maybeTilingResult))
3450 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3452 tiled.append(maybeTilingResult->tiledOps);
3454 loops[en2.index()].push_back(en2.value());
3457 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3459 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3468 results.reserve(tileSizes.size());
3469 unsigned dynamicPos = 0;
3471 for (int64_t size : tileSizes) {
3472 if (size == ShapedType::kDynamic) {
3473 results.push_back(dynamic[dynamicPos++]);
3481 void transform::TileUsingForOp::getEffects(
3493 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3497 ArrayAttr mapping) {
3498 return build(builder, result,
3506 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3510 ArrayAttr mapping) {
3520 build(builder, result,
3521 TypeRange{operationType, operationType},
3528 staticTileSizesAttr,
3532 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3536 ArrayAttr mapping) {
3537 return build(builder, result, target,
3542 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3546 ArrayAttr mapping) {
3557 build(builder, result,
3558 TypeRange{operationType, operationType},
3564 staticNumThreadsAttr,
3577 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3579 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3581 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3582 normalizedUbs.push_back(normalizedUb);
3584 return normalizedUbs;
3600 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3603 denormalizedIvs.push_back(
3606 return denormalizedIvs;
3617 scf::ForallOp loop) {
3634 auto normalizedForallOp = scf::ForallOp::create(
3635 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3636 loop.getOutputs(), loop.getMapping(),
3639 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3641 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3646 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3647 normalizedForallOp.getRegionIterArgs().end());
3648 Block *origLoopBlock = loop.getBody();
3649 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3651 rewriter.
replaceOp(loop, normalizedForallOp);
3652 return normalizedForallOp;
3657 TransformOpInterface transformOp,
Operation *target,
3662 auto tileableOp = dyn_cast<TilingInterface>(target);
3665 transformOp.emitSilenceableError()
3666 <<
"only TilingInterface ops are supported";
3667 diag.attachNote(target->
getLoc()) <<
"target op";
3673 if (!mixedNumThreads.empty()) {
3674 options.setNumThreads(mixedNumThreads);
3676 options.setTileSizes(mixedTileSizes);
3679 options.setMapping(mapping.value().getValue());
3681 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3684 if (
failed(maybeTilingResult))
3685 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3687 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3689 tilingResult = *maybeTilingResult;
3691 if (mixedNumThreads.empty()) {
3692 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3695 scf::ForallOp normalizedForallOp =
3697 tilingResult.
loops.front() = normalizedForallOp;
3707 auto transformOp = cast<TransformOpInterface>(getOperation());
3716 getPackedNumThreads()
3718 state, transformOp, mixedNumThreads, getPackedNumThreads())
3720 state, transformOp, mixedNumThreads, getMixedNumThreads());
3724 status = getPackedTileSizes()
3726 state, transformOp, mixedTileSizes, getPackedTileSizes())
3728 state, transformOp, mixedTileSizes, getMixedTileSizes());
3732 for (
Operation *target : state.getPayloadOps(getTarget())) {
3735 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3736 getMapping(), tilingResult);
3737 if (!
diag.succeeded())
3739 tileOps.push_back(tilingResult.
loops.front());
3740 tiledOps.append(tilingResult.
tiledOps);
3743 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3744 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3749 void transform::TileUsingForallOp::getEffects(
3762 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3771 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3772 static_cast<int>(getPackedNumThreads() !=
Value());
3773 if (numThreadsSpec > 1)
3775 "num_threads and packed_num_threads are mutually exclusive");
3776 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3777 static_cast<int>(getPackedTileSizes() !=
Value());
3778 if (tileSizesSpec > 1)
3780 "tile_sizes and packed_tile_sizes are mutually exclusive");
3781 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3782 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3783 "must be specified");
3791 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3793 bool foldTypeExtensionsIntoContract,
bool vectorizePadding,
3794 bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3796 if (foldTypeExtensionsIntoContract) {
3798 VectorizeChildrenAndApplyPatternsOp::
3799 getFoldTypeExtensionsIntoContractAttrName(result.
name),
3802 if (vectorizePadding) {
3804 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3808 if (vectorizeExtract) {
3810 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3814 if (flatten1DDepthwiseConv) {
3816 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3827 explicit VectorizationPattern(
MLIRContext *context,
3828 bool vectorizeExtract =
false,
3829 bool flattenConv =
false)
3831 vectorizeNDExtract(vectorizeExtract),
3832 flatten1DDepthwiseConv(flattenConv) {}
3833 LogicalResult matchAndRewrite(
Operation *op,
3837 "Unsupported Op, cannot vectorize");
3838 FailureOr<VectorizationResult> vectorResults =
3840 {}, vectorizeNDExtract,
3841 flatten1DDepthwiseConv);
3842 if (
failed(vectorResults))
3844 rewriter.
replaceOp(op, vectorResults->replacements);
3851 bool vectorizeNDExtract =
false;
3855 bool flatten1DDepthwiseConv =
false;
3860 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3865 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3866 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3872 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3873 getFlatten_1dDepthwiseConv());
3875 if (!getDisableTransferPermutationMapLoweringPatterns())
3878 if (!getDisableMultiReductionToContractPatterns())
3886 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
3887 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
3892 if (getFoldTypeExtensionsIntoContract())
3895 if (getVectorizePadding()) {
3907 return emitDefaultDefiniteFailure(target);
3921 auto targets = state.getPayloadOps(getTarget());
3922 if (std::empty(targets))
3924 auto transformOp = cast<TransformOpInterface>(getOperation());
3927 state, transformOp, getMixedVectorSizes(), vectorSizes);
3935 <<
"Unsupported Op, cannot vectorize";
3937 FailureOr<VectorizationResult> vectorResults =
3939 getVectorizeNdExtract().value_or(
false),
3941 getAssumeDynamicDimsMatchVecSizes().value_or(
false),
3942 getCreateNamedContraction().value_or(
false));
3943 if (
failed(vectorResults)) {
3945 <<
"Attempted to vectorize, but failed";
3947 rewriter.
replaceOp(target, vectorResults->replacements);
3953 void transform::VectorizeOp::getEffects(
3962 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3966 if (getStaticVectorSizes().size() != getScalableSizes().size())
3967 return emitOpError(
"expected same number of vector sizes (")
3968 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3969 << getScalableSizes().size() <<
")";
3978 transform::HoistRedundantVectorTransfersOp::applyToOne(
3995 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4014 auto maybeTransformed =
4017 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4020 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4023 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4026 .Case([&](linalg::Conv2DNchwFchwOp op) {
4032 if (
failed(maybeTransformed))
4033 return emitDefaultSilenceableFailure(target);
4035 results.
push_back(maybeTransformed->first);
4037 results.
push_back(maybeTransformed->second);
4052 <<
"only elementwise flattening is supported";
4055 if (target.getNumLoops() <= 1) {
4062 std::iota(reassociation.begin(), reassociation.end(), 0);
4063 auto maybeFlattened =
4065 if (
failed(maybeFlattened))
4067 <<
"attempted to flatten, but failed";
4068 results.
push_back(maybeFlattened->collapsedOp);
4069 rewriter.
replaceOp(target, maybeFlattened->results);
4082 auto maybeTransformed =
4084 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4087 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4093 if (
failed(maybeTransformed))
4094 return emitDefaultSilenceableFailure(target);
4109 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4110 auto maybeTransformed =
4112 .Case([&](linalg::MatmulOp op) {
4115 .Case([&](linalg::BatchMatmulOp op) {
4118 .Default([&](
Operation *op) {
return failure(); });
4119 if (
failed(maybeTransformed))
4129 template <
typename OpTy>
4133 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4134 tensor::ParallelInsertSliceOp>() &&
4137 if (
auto copySource =
4138 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4145 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
4147 target->template getParentOfType<scf::InParallelOp>());
4150 Value extracted = tensor::ExtractSliceOp::create(
4151 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4152 target.getMixedSizes(), target.getMixedStrides());
4153 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4154 target.getSource(), extracted)
4159 target, copied, target.getDest(), target.getMixedOffsets(),
4160 target.getMixedSizes(), target.getMixedStrides());
4172 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4173 return doit(rewriter, target, results, state);
4174 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4175 return doit(rewriter, target, results, state);
4178 emitSilenceableError()
4179 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4180 diag.attachNote(targetOp->
getLoc()) <<
"target op";
4193 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4195 emitSilenceableError()
4196 <<
"only linalg.copy and tensor.pad target ops are supported";
4197 diag.attachNote(target->
getLoc()) <<
"target op";
4200 assert(target->
getNumResults() == 1 &&
"expected single result");
4202 if (!resultShapedType.hasStaticShape()) {
4204 emitSilenceableError()
4205 <<
"only statically sized ops of rank <= 3 are supported";
4206 diag.attachNote(target->
getLoc()) <<
"target op";
4211 int64_t desiredBitAlignment = getDesiredBitAlignment();
4212 int64_t eltBitwidth =
4213 resultShapedType.getElementType().getIntOrFloatBitWidth();
4214 if (desiredBitAlignment % eltBitwidth != 0) {
4215 desiredBitAlignment = eltBitwidth;
4220 getTotalNumThreads(),
4221 desiredBitAlignment,
4222 resultShapedType.getShape(),
4225 resultShapedType.getElementType().getIntOrFloatBitWidth());
4226 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4228 emitSilenceableError()
4229 <<
"too few threads to map copy op to threads on the most minor "
4230 "dimension, given alignment and vector size constraints, try "
4231 "smaller tile size of mapping to more threads";
4232 diag.attachNote(target->
getLoc()) <<
"target op";
4248 if (!
diag.succeeded())
4252 for (
auto op : tilingResult.
tiledOps)
4266 FailureOr<Operation *> maybeTransformed = failure();
4268 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4273 .Default([&](
Operation *op) {
return false; });
4276 return emitSilenceableError()
4277 <<
"this operation is not supported to convert to Winograd Conv2D";
4280 if (
failed(maybeTransformed)) {
4281 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4293 FailureOr<Operation *> maybeTransformed = failure();
4296 .Case([&](linalg::WinogradFilterTransformOp op) {
4300 .Case([&](linalg::WinogradInputTransformOp op) {
4304 .Case([&](linalg::WinogradOutputTransformOp op) {
4308 .Default([&](
Operation *op) {
return false; });
4312 emitSilenceableError()
4313 <<
"this operation is not supported to decompose into other operations";
4314 diag.attachNote(target->
getLoc()) <<
"target op";
4318 if (
failed(maybeTransformed)) {
4320 emitSilenceableError() <<
"decompose Winograd operations failed";
4321 diag.attachNote(target->
getLoc()) <<
"target op";
4329 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4331 #define GET_OP_CLASSES
4332 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)).
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.