43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/DebugLog.h"
47 #include "llvm/Support/LogicalResult.h"
48 #include <type_traits>
54 #define DEBUG_TYPE "linalg-transforms"
61 template <
typename PatternTy,
typename... Args>
64 using OpTy =
typename llvm::function_traits<
65 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
66 auto op = dyn_cast<OpTy>(operation);
71 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
76 auto result = pattern.returningMatchAndRewrite(op, rewriter);
79 return cast<LinalgOp>(result->getOperation());
89 if (
auto attr = dyn_cast<Attribute>(ofr)) {
90 if (!isa<IntegerAttr>(attr))
91 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
92 result.push_back(ofr);
96 Value transformValue = cast<Value>(ofr);
97 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
99 if (params.size() != 1)
100 return transformOp.emitDefiniteFailure()
101 <<
"requires exactly one parameter associated";
102 result.push_back(params[0]);
106 auto payloadOps = state.getPayloadOps(transformValue);
107 if (!llvm::hasSingleElement(payloadOps)) {
109 transformOp.emitSilenceableError()
110 <<
"handle must be mapped to exactly one payload op";
112 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
119 transformOp.emitSilenceableError()
120 <<
"payload op must have exactly 1 index result";
140 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
142 for (
auto param : params) {
143 if (!isa<IntegerAttr>(param))
144 return transformOp.emitDefiniteFailure()
145 <<
"expected the parameter to be associated with an integer "
147 result.push_back(param);
152 for (
Operation *op : state.getPayloadOps(packedHandle)) {
153 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
155 transformOp.emitSilenceableError()
156 <<
"payload op must have exactly 1 index result";
157 diag.attachNote(op->getLoc())
158 <<
"has " << op->getNumResults() <<
" results";
161 result.push_back(op->getResult(0));
175 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
176 reified.push_back(cast<IntegerAttr>(attr).getInt());
178 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
180 if (params.size() != 1)
181 return transformOp.emitSilenceableError() <<
"expected a single param";
183 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
187 Value handle = cast<Value>(paramOrHandle);
188 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
189 return transformOp.emitSilenceableError() <<
"unexpected value handle";
190 auto payload = state.getPayloadOps(handle);
191 if (!llvm::hasSingleElement(payload))
192 return transformOp.emitSilenceableError()
193 <<
"requires param or handle that is mapped to 1 payload op";
195 Operation *paramOrHandlePayloadOp = *payload.begin();
198 return transformOp.emitSilenceableError()
199 <<
"requires param or handle to be result of op with 1 index "
205 return transformOp.emitSilenceableError()
206 <<
"requires param or handle to be the result of a constant like "
209 reified.push_back(attr.getInt());
218 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
223 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
228 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
233 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
239 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
242 options.rankReductionStrategy =
247 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
252 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
257 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
262 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
267 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
276 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
281 resultTypes.push_back(b.
getType<transform::AnyValueType>());
282 resultTypes.push_back(b.
getType<transform::AnyOpType>());
283 return build(b, result,
289 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
292 int64_t memorySpace) {
294 resultTypes.push_back(b.
getType<transform::AnyValueType>());
295 resultTypes.push_back(b.
getType<transform::AnyOpType>());
296 return build(b, result,
312 void notifyOperationInserted(
Operation *op,
314 ForwardingListener::notifyOperationInserted(op, previous);
316 if (previous.
isSet())
318 auto inserted = newOps.insert(op);
320 assert(inserted.second &&
"expected newly created op");
323 void notifyOperationErased(
Operation *op)
override {
324 ForwardingListener::notifyOperationErased(op);
338 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
339 NewOpsListener newOpsListener(previousListener);
343 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
346 }
else if (getMemcpyOp() ==
"memref.copy") {
349 }
else if (getMemcpyOp() ==
"linalg.copy") {
353 llvm_unreachable(
"invalid memcpy op");
355 if (getAllocOp() ==
"memref.alloc") {
358 }
else if (getAllocOp() ==
"memref.alloca") {
362 llvm_unreachable(
"invalid alloc op");
364 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
365 options.emitDealloc = getEmitDealloc();
369 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
371 for (
Operation *op : state.getPayloadOps(getTarget())) {
376 <<
"failed to bufferize operation";
377 diag.attachNote(op->
getLoc()) <<
"target payload op";
380 allocatedBuffers.push_back(buffer);
384 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
385 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
389 void transform::BufferizeToAllocationOp::getEffects(
391 if (getBufferizeDestinationOnly()) {
403 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
404 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
405 return emitOpError() <<
"unsupported memcpy op";
406 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
407 return emitOpError() <<
"unsupported alloc op";
420 #define DOWNSCALE(trans) \
422 FailureOr<LinalgOp> res = tryApply<trans>(target); \
423 if (succeeded(res)) { \
424 results.push_back(*res); \
425 return DiagnosedSilenceableFailure::success(); \
429 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
430 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
443 #undef DOWNSCALE_NORMAL
444 #undef DOWNSCALE_CALL
446 return emitDefaultSilenceableFailure(target);
460 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
461 if (!decomposableOp) {
463 "payload is not a decomposable op"));
464 return emitDefaultSilenceableFailure(target);
467 FailureOr<SmallVector<Value>> maybeNewResults =
468 decomposableOp.decomposeOperation(rewriter);
469 if (
failed(maybeNewResults))
470 return emitDefaultSilenceableFailure(target);
472 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
473 for (
Value val : *maybeNewResults) {
474 Operation *definition = val.getDefiningOp();
485 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
492 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
496 options.allowReturnAllocsFromLoops =
true;
498 for (
Operation *target : state.getPayloadOps(getTarget())) {
502 <<
"failed to analyze op";
504 rewriter, target, state)))
506 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
517 template <
typename Range>
521 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
527 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
528 if (!tilingInterfaceOp)
529 return transformOp->
emitError(
"only TilingInterface ops are supported");
532 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
533 applyFn(tilingInterfaceOp);
539 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
540 for (
Operation *toReplace : opsToReplace) {
541 for (
OpResult res : toReplace->getResults())
542 if (
auto replacement = tiledResults->replacements.lookup(res))
544 if (toReplace->use_empty()) {
550 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
551 assert(tiledResults->loops.size() == numLoops &&
552 "Mismatched number of loops, tile and fuse transform should have "
554 for (
unsigned int i = 0; i < numLoops; ++i)
555 loopOps[i].push_back(tiledResults->loops[i]);
558 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
559 for (
unsigned int i = 0; i < numLoops; ++i)
560 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
572 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
578 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
582 if (getApplyCleanup()) {
585 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
592 rewriter, getOperation(), state.getPayloadOps(getTarget()),
593 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
594 [&](TilingInterface tilingInterfaceOp)
595 -> FailureOr<scf::SCFTileAndFuseResult> {
596 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
605 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
606 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
607 if (!std::is_permutation(sequence.begin(), sequence.end(),
608 permutation.begin(), permutation.end())) {
609 return emitOpError() <<
"expects interchange to be a permutation, found "
610 << getTileInterchange();
615 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
616 if (numExpectedLoops != getNumResults() - 1)
617 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
626 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
629 Value containingOp) {
632 result.
addTypes({resultType, resultType});
648 (domInfo.
dominates(containingOp, user))) {
649 dominatedUsers.insert(user);
652 if (dominatedUsers.empty())
656 auto forallOp = cast<scf::ForallOp>(containingOp);
662 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
667 newOuts.push_back(outputs[resultNumber]);
670 auto newforallOp = scf::ForallOp::create(
671 rewriter, loc, forallOp.getMixedLowerBound(),
672 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
673 forallOp.getMapping());
675 newforallOp.getRegion().takeBody(forallOp.getRegion());
680 newforallOp.getBody()->addArgument(newOuts.back().getType(),
681 newOuts.back().getLoc());
682 auto bbArgs = newforallOp.getBody()->getArguments();
685 Operation *op = use.getOwner();
686 return newforallOp->isProperAncestor(op);
690 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
692 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
693 Operation *firstYieldOp = yieldingOps.front();
696 Value dst = newforallOp.getRegionIterArgs().back();
698 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->
getLoc(), src,
699 dst, offsets, sizes, strides);
703 newforallOp->getResult(result.index()));
706 newforallOp->getResults().back(),
708 Operation *user = use.getOwner();
709 return dominatedUsers.contains(user);
723 destWorklist.push_back(dst);
725 while (!destWorklist.empty()) {
726 Value currentDst = destWorklist.pop_back_val();
730 if (src == currentDst)
735 auto bbArg = dyn_cast<BlockArgument>(currentDst);
739 Block *parentBlock = bbArg.getOwner();
740 assert(parentBlock &&
"unlinked block argument");
743 assert(parentOp &&
"expected block argument with parent operation");
746 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
750 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
752 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
753 Value loopBlockArgument =
755 destWorklist.push_back(loopBlockArgument);
768 static std::tuple<SmallVector<Operation *>,
Operation *>
771 LDBG() <<
"Try to fuse a direct extract use";
772 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
773 if (!tileableProducer) {
775 <<
"producer is not a TileableInterface: " << *producerOp;
782 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
783 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
784 return sliceOp && containingOp->isProperAncestor(sliceOp);
788 if (it == tileableProducer->getUsers().end()) {
789 diag.attachNote(tileableProducer->getLoc())
790 <<
"could not find fusion opportunity for: " << *tileableProducer;
793 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
806 if (LoopLikeOpInterface containerLoop =
807 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
814 cast<DestinationStyleOpInterface>(
clone).getDpsInitsMutable()) {
815 Value producerOperand =
818 containerLoop.getRegionIterArgs()) {
819 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
820 Value consumerOperand =
821 containerLoop->getOperand(bbArg->getOperandNumber());
823 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
824 initOperandPtr.set(containerIterArg);
830 tileableProducer = dyn_cast<TilingInterface>(
clone);
834 int64_t resultNumber =
835 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
836 LDBG() <<
"resultNumber: " << resultNumber;
841 FailureOr<TilingResult> tileAndFuseResult =
842 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
845 if (
failed(tileAndFuseResult)) {
846 diag.attachNote(tileableProducer->getLoc())
847 <<
"failed to tile producer op: " << *tileableProducer;
852 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
853 LDBG() <<
"tiledProducer: " << *tiledOp;
858 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
859 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
860 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
861 if (
failed(maybeRankReduced)) {
863 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
864 << tileAndFuseResult->tiledValues[0]
865 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
868 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
872 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
873 resultNumber, offsets, sizes);
876 if (dyn_cast<LoopLikeOpInterface>(containingOp))
877 rewriter.
eraseOp(tileableProducer);
879 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
892 LDBG() <<
"Try to fuse an extract use through block argument";
894 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
895 if (!tileableProducer) {
897 <<
"producer is not a TileableInterface: " << *producerOp;
902 scf::ForallOp forallOp;
903 auto itProducerUses =
904 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
905 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
909 if (!forallOp || forallOp != containingOp) {
910 diag.attachNote(tileableProducer->getLoc())
911 <<
"could not find a use by the containing op: " << *tileableProducer;
926 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
927 return sliceOp && containingOp->isProperAncestor(sliceOp);
931 if (itBBArgUsers == bbArg.
getUsers().end()) {
933 <<
"could not find fusion opportunity for bbArg: " << bbArg;
936 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
944 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
945 LDBG() <<
"resultNumber: " << resultNumber;
950 rewriter, tileableProducer->getLoc(), tileableProducer,
951 destinationTensors))) {
952 diag.attachNote(tileableProducer->getLoc())
953 <<
"failed to get destination tensors for: " << *tileableProducer;
958 bvm.
map(destinationTensors[resultNumber], bbArg);
959 auto tileableProducerClone =
960 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
962 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
965 FailureOr<TilingResult> tileAndFuseResult =
966 tileableProducerClone.generateResultTileValue(
967 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
968 sliceOpToTile.getMixedSizes());
969 if (
failed(tileAndFuseResult)) {
970 diag.attachNote(tileableProducer->getLoc())
971 <<
"failed to tile producer op: " << *tileableProducer;
976 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
977 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
978 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
979 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
980 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
985 destinationTensors.front());
988 return tileAndFuseResult->tiledOps;
994 LDBG() <<
"Try to fuse an use by cloning";
999 for (
OpOperand &use : result.getUses()) {
1001 uses.push_back(&use);
1006 if (containingOp == use.getOwner()) {
1008 <<
"producer op use by containing op cannot be fused by cloning";
1016 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1025 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1026 "Parallel insert slice is not a valid clone destination");
1027 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1028 LDBG() <<
"resultNumber: " << resultNumber;
1032 fusedOp = rewriter.
clone(*producerOp);
1034 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1039 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1049 auto producerOps = state.getPayloadOps(getProducerOp());
1050 auto containingOps = state.getPayloadOps(getContainingOp());
1051 if (!llvm::hasSingleElement(containingOps)) {
1053 <<
"requires exactly one containing_op handle (got "
1054 << llvm::range_size(containingOps) <<
")";
1056 Operation *containingOp = *containingOps.begin();
1059 if (std::empty(producerOps)) {
1061 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1068 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1069 for (
const auto &it :
enumerate(remainingProducers)) {
1072 int64_t numUsesInContainingOp =
1074 return containingOp->isAncestor(op);
1079 if (numUsesInContainingOp > 0) {
1080 if (numUsesInContainingOp == 1)
1081 remainingProducers.erase(remainingProducers.begin() + it.index());
1088 while (!remainingProducers.empty()) {
1089 auto nextProducer = getNextProducer();
1090 if (
failed(nextProducer)) {
1092 <<
"could not find next producer to fuse into container";
1093 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1101 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1108 auto [tiledOps, newContainingOp] =
1110 if (!tiledOps.empty()) {
1111 LDBG() <<
"\nFused a direct extract use\n" << *containingOp;
1112 fusedOps.append(tiledOps);
1113 if (newContainingOp) {
1121 LogicalResult replacementStatus =
1124 (void)replacementStatus;
1125 assert(succeeded(replacementStatus) &&
1126 "unable to update transform state mapping");
1127 rewriter.
eraseOp(containingOp);
1128 containingOp = newContainingOp;
1135 rewriter,
diag, producerOp, containingOp);
1136 if (!tiledContainingOpOperand.empty()) {
1137 LDBG() <<
"\nFused an extract use through block argument\n"
1139 fusedOps.append(tiledContainingOpOperand);
1146 LDBG() <<
"\nFused an use by cloning\n" << *containingOp;
1147 fusedOps.push_back(cloned);
1153 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1154 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1158 void transform::FuseIntoContainingOp::getEffects(
1176 if (isa<GenericOp>(target)) {
1182 if (succeeded(
generic)) {
1183 results.
push_back(generic->getOperation());
1186 return emitDefaultSilenceableFailure(target);
1199 if (!isa<GenericOp>(target)) {
1204 FailureOr<LinalgOp> named =
1206 if (succeeded(named)) {
1207 results.
push_back(named->getOperation());
1210 return emitDefaultSilenceableFailure(target);
1224 if (interchangeVector.empty()) {
1229 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1230 if (interchangeVector.size() != numLoops) {
1231 return emitSilenceableError()
1232 << getIteratorInterchangeAttrName() <<
" has length ("
1233 << interchangeVector.size()
1234 <<
") different from the number of loops in the target operation ("
1247 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1248 if (!std::is_permutation(sequence.begin(), sequence.end(),
1249 permutation.begin(), permutation.end())) {
1250 return emitOpError()
1251 <<
"expects iterator_interchange to be a permutation, found "
1252 << getIteratorInterchange();
1267 if (!isa<linalg::CopyOp>(targetOp)) {
1269 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1270 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1274 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1275 if (!copyOp.hasPureBufferSemantics()) {
1277 emitSilenceableError()
1278 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1279 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1285 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1286 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1287 Value input = inputs.front();
1288 Value output = outputs.front();
1293 if (!isa<ShapedType>(input.
getType())) {
1295 emitSilenceableError()
1296 <<
"cannot transform a linalg.copy which input has no shape";
1297 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1302 assert(isa<ShapedType>(output.getType()));
1304 if (cast<ShapedType>(input.
getType()).getElementType() !=
1305 cast<ShapedType>(output.getType()).getElementType()) {
1307 emitSilenceableError()
1308 <<
"cannot transform a linalg.copy with different source and "
1309 "destination element types ";
1310 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1331 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1332 FailureOr<LowerPackResult> res =
1333 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1336 <<
"cannot lower to pad + expand + transpose";
1339 transformResults.
push_back(res->expandShapeOp);
1340 transformResults.
push_back(res->transposeOp);
1353 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1354 FailureOr<LowerUnPackOpResult> res =
1355 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1358 emitSilenceableError()
1359 <<
"cannot lower to transpose + collapse + extract";
1360 diag.attachNote(target->getLoc()) <<
"target payload op";
1363 transformResults.
push_back(res->emptyOp);
1364 transformResults.
push_back(res->transposeOp);
1365 transformResults.
push_back(res->collapseShapeOp);
1366 transformResults.
push_back(res->extractSliceOp);
1396 if (getOps().has_value())
1397 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1399 auto payloadOps = state.getPayloadOps(getTarget());
1400 if (!llvm::hasSingleElement(payloadOps)) {
1405 bool incorrectNumOperandTypes =
false;
1412 if (getInterface().has_value()) {
1413 auto iface = getInterface().value();
1414 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1417 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1418 !isa<TilingInterface>(op))
1420 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1421 !isa<LoopLikeOpInterface>(op))
1426 if (getOpAttrs().has_value()) {
1427 DictionaryAttr opAttrs = getOpAttrs().value();
1429 if (attr.getName() == getInterfaceAttrName() ||
1430 attr.getName() == getOpsAttrName())
1432 if (!op->
hasAttr(attr.getName()))
1434 if (op->
getAttr(attr.getName()) != attr.getValue())
1439 if (getFilterResultType().has_value()) {
1440 Type t = getFilterResultType().value();
1445 if (getFilterOperandTypes().has_value()) {
1446 mlir::ArrayAttr types = getFilterOperandTypes().value();
1449 if (types.size() == 1) {
1452 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1453 Type t = cast<::mlir::Type>(typeattr.getValue());
1455 [&](
Type operandType) { return operandType == t; }))
1460 if (types.size() != operandTypes.size()) {
1461 incorrectNumOperandTypes =
true;
1465 for (
auto [attr, operandType] :
1466 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1467 auto typeattr = cast<mlir::TypeAttr>(attr);
1468 Type type = cast<::mlir::Type>(typeattr.getValue());
1470 if (type != operandType)
1481 (*payloadOps.begin())->
walk(matchFun);
1482 if (incorrectNumOperandTypes)
1484 "type, then it must contain as much types as "
1485 "the number of operands in the target ops");
1486 results.
set(cast<OpResult>(getResult()), res);
1501 Type &targetType,
Type &lowSizeType,
1503 Type &splitPointType) {
1504 FunctionType funcType;
1509 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1510 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1511 "argument and one result";
1513 targetType = funcType.getInput(0);
1514 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1522 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1523 if (target.hasDynamicShape()) {
1524 auto diag = emitSilenceableError()
1525 <<
"cannot compute parametric tile sizes for dynamically "
1526 "shaped payload op";
1527 diag.attachNote(target->getLoc()) <<
"payload op";
1532 target, getDimension(), getTargetSize(), getDivisor());
1534 return emitSilenceableError()
1535 <<
"failed to compute multi-size tiling sizes";
1538 Builder builder(target.getContext());
1539 results.
assign(llvm::map_range(
1541 spec->lowTileSize * spec->lowTripCount}),
1542 [&builder,
this](int64_t value) {
1554 builder, target, getDimension(), targetSize, divisor);
1556 return emitSilenceableError() <<
"could not generate tile size computation";
1563 {spec->lowTileSize, spec->lowTripCount});
1564 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1565 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1566 assert(lowTileSize && highTileSize && splitPoint &&
1567 "tile sizes are not produced by operations");
1575 void transform::MultiTileSizesOp::getEffects(
1579 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1588 return emitOpError() <<
"expects all results type to be the same";
1608 builder.
getContext(), GenericOp::getOperationName());
1609 build(builder, result,
1618 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1625 auto targetOps = state.getPayloadOps(getTarget());
1627 if (std::empty(targetOps)) {
1628 transformResults.
set(cast<OpResult>(getPackedOp()),
1633 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1634 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1635 return emitSilenceableError()
1636 <<
"requires target to map to exactly 1 LinalgOp (got "
1637 << llvm::range_size(targetOps) <<
")";
1640 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1641 return emitSilenceableError()
1642 <<
"requires number of packed sizes match the number of loops ("
1643 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1650 state, *
this, packedSizes, getMixedPackedSizes());
1653 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1657 transformResults.
set(cast<OpResult>(getPackedOp()),
1658 {maybeResult->packedLinalgOp.getOperation()});
1662 void transform::PackOp::getEffects(
1676 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1677 <<
" is not a valid permutation";
1680 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1681 for (
auto [s, nmo] :
1682 llvm::zip_equal(getMixedMatmulPackedSizes(),
1683 getMatmulPaddedSizesNextMultipleOf())) {
1686 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1687 return emitOpError() <<
"at most one of the packed_size and the "
1688 "padded_sizes_next_multiple_of can be nonzero "
1689 "for the matmul strategy";
1701 for (
Operation *op : state.getPayloadOps(getTarget())) {
1702 auto linalgOp = dyn_cast<LinalgOp>(op);
1713 getMixedMatmulPackedSizes(),
1715 getMatmulPaddedSizesNextMultipleOf(),
1716 getMatmulInnerDimsOrder());
1717 if (succeeded(packResult)) {
1718 results.push_back(packResult->packedLinalgOp);
1721 results.push_back(linalgOp);
1723 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1729 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1733 void transform::PackGreedilyOp::getEffects(
1747 return emitOpError() << getInnerPermAttrName()
1748 <<
" is not a valid permutation";
1751 return emitOpError() << getOuterPermAttrName()
1752 <<
" is not a valid permutation";
1754 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1755 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1756 <<
" or " << getOuterPermAttrName()
1757 <<
" must be specified";
1763 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1773 template <
typename RelayoutOpTy>
1776 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1778 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1779 "applies to only pack or unpack operations");
1780 if (!op || permutation.empty())
1782 size_t innerRank = op.getInnerDimsPos().size();
1783 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1787 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1788 return permutation.size() == op.getSourceRank() &&
1791 return permutation.size() == op.getDestRank() &&
1799 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1800 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1802 if (std::empty(packOrUnpackOps)) {
1803 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1804 transformResults.
set(cast<OpResult>(getPackOp()), {});
1805 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1811 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1812 !llvm::hasSingleElement(linalgOps)) {
1813 return emitSilenceableError()
1814 <<
"requires target to map to exactly 1 "
1815 "packing op and 1 packed op ("
1816 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1817 << llvm::range_size(linalgOps) <<
")";
1821 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1822 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1823 if ((!packOp && !unPackOp)) {
1824 return emitSilenceableError() <<
"requires target to map to a "
1825 "linalg.pack or linalg.unpack";
1827 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1828 if (!linalgOpTarget)
1829 return emitSilenceableError() <<
"requires a LinalgOp target";
1833 if (packOp && packOp.getResult().hasOneUse())
1834 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1836 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1837 if (linalgOp != linalgOpTarget) {
1839 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1840 : StringLiteral{
"not produced by the LinalgOp target"};
1841 return emitSilenceableError() << errorMsg;
1847 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1848 OpOperand *packUse = linalgOp.getDpsInitOperand(
1849 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1851 if (!packOp || !packOp.getResult().hasOneUse())
1852 return emitSilenceableError() <<
"could not find matching pack op";
1856 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1858 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1859 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1860 ? StringLiteral{
"invalid outer_perm"}
1861 : StringLiteral{
"invalid inner_perm"};
1865 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1866 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1872 assert(packOp && linalgOp &&
"unexpected null op");
1876 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1878 assert(succeeded(res) &&
"unexpected packTranspose failure");
1881 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1882 transformResults.
set(cast<OpResult>(getPackedOp()),
1883 {res->transposedLinalgOp});
1885 transformResults.
set(cast<OpResult>(getUnPackOp()),
1886 {res->transposedUnPackOp});
1888 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1903 StringRef copyBackOp,
1904 bool usePrescribedTensorShapes) {
1914 (padToMultipleOf.empty()
1916 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1917 b.getI64ArrayAttr(nofoldFlags),
1918 b.getArrayAttr(transposePaddings),
1919 b.getStringAttr(copyBackOp),
1921 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
1929 StringRef copyBackOp,
1930 bool usePrescribedTensorShapes) {
1935 staticPadToMultipleOf);
1942 dynamicPadToMultipleOf,
1943 staticPadToMultipleOf,
1947 usePrescribedTensorShapes);
1950 void PadOp::getEffects(
1960 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1967 auto transformOp = cast<TransformOpInterface>(getOperation());
1970 for (
Operation *target : state.getPayloadOps(getTarget())) {
1971 auto linalgTarget = dyn_cast<LinalgOp>(target);
1972 if (!linalgTarget) {
1973 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1974 diag.attachNote(target->
getLoc()) <<
"target op";
1980 for (int64_t packPadding :
1981 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1982 nofoldFlags.push_back(
static_cast<bool>(packPadding));
1986 for (
auto const &[untypedAttr, elementOrTensorType] :
1987 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1989 if (isa<ub::PoisonAttr>(untypedAttr)) {
1990 paddingValues.push_back(untypedAttr);
1993 auto attr = dyn_cast<TypedAttr>(untypedAttr);
1995 emitOpError(
"expects padding values to be typed attributes or poison");
2000 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2004 if (!parsedAttr || parsedAttr.getType() != elementType) {
2005 auto diag = this->emitOpError(
"expects a padding that parses to ")
2006 << elementType <<
", got " << untypedAttr;
2007 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2010 paddingValues.push_back(parsedAttr);
2014 if (attr.getType() != elementType) {
2015 auto diag = this->emitOpError(
"expects a padding value of type ")
2016 << elementType <<
", got " << attr;
2017 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2020 paddingValues.push_back(attr);
2025 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2026 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2027 cast<ArrayAttr>(transposeVector)));
2032 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2036 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2039 if (padToMultipleOf.empty())
2043 options.padToMultipleOf = padToMultipleOf;
2044 options.paddingValues = paddingValues;
2045 options.nofoldFlags = nofoldFlags;
2046 if (getCopyBackOp() ==
2047 bufferization::MaterializeInDestinationOp::getOperationName()) {
2050 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2052 }
else if (getCopyBackOp() == kCopyOpNone) {
2055 llvm_unreachable(
"unsupported copy_back op");
2058 bool irChanged =
false;
2059 if (getUsePrescribedTensorShapes() &&
2060 linalgTarget.hasPureTensorSemantics()) {
2063 for (
OpOperand &operand : linalgTarget->getOpOperands()) {
2064 for (
auto [i, dim] :
llvm::enumerate(linalgTarget.getShape(&operand))) {
2065 if (ShapedType::isStatic(dim))
2067 options.setSizeToPadTo(operand.getOperandNumber(), i,
2069 operand.get().getLoc(),
2079 replacements, newPadOps))) {
2082 diag.attachNote(target->
getLoc()) <<
"target op";
2085 auto diag = emitSilenceableError() <<
"failed to pad op";
2086 diag.attachNote(target->
getLoc()) <<
"target op";
2095 rewriter.
replaceOp(linalgTarget, replacements);
2096 paddedOps.push_back(paddedOp);
2097 padOps.append(newPadOps.begin(), newPadOps.end());
2099 for (
Value v : replacements) {
2100 Operation *copyBackOp = v.getDefiningOp();
2101 if (!llvm::is_contained(copyBackOps, copyBackOp))
2102 copyBackOps.push_back(copyBackOp);
2107 results.
set(cast<OpResult>(getPadded()), paddedOps);
2108 results.
set(cast<OpResult>(getPad()), padOps);
2109 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2115 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2116 if (any_of(nofoldFlags, [](int64_t packPadding) {
2117 return packPadding != 0 && packPadding != 1;
2119 return emitOpError()
2120 <<
"expects nofold_flags to contain booleans (0/1), found "
2121 << getNofoldFlags();
2125 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2126 if (any_of(paddingDimensions,
2127 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2128 return emitOpError() <<
"expects padding_dimensions to contain positive "
2130 << getPaddingDimensions();
2132 if (!getMixedPadToMultipleOf().empty()) {
2133 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2134 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2137 ArrayAttr transposes = getTransposePaddings();
2140 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2141 if (!std::is_permutation(sequence.begin(), sequence.end(),
2142 transpose.begin(), transpose.end())) {
2143 return emitOpError()
2144 <<
"expects transpose_paddings to be a permutation, found "
2148 if (getCopyBackOp() !=
2149 bufferization::MaterializeInDestinationOp::getOperationName() &&
2150 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2151 getCopyBackOp() != kCopyOpNone)
2152 return emitOpError() <<
"invalid copy_back_op";
2160 void transform::PadTilingInterfaceOp::build(
OpBuilder &b,
2164 bool padToMultipleOf) {
2174 : b.getDenseI64ArrayAttr(paddingSizes)),
2176 padToMultipleOf ? b.getUnitAttr() : nullptr);
2179 void transform::PadTilingInterfaceOp::build(
2186 staticPaddingSizes);
2192 dynamicPaddingSizes,
2197 void transform::PadTilingInterfaceOp::getEffects(
2206 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2208 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2217 for (
Operation *target : state.getPayloadOps(getTarget())) {
2218 auto targetOp = dyn_cast<TilingInterface>(target);
2220 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2221 diag.attachNote(target->
getLoc()) <<
"target op";
2228 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2229 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2231 diag.attachNote(target->
getLoc()) <<
"target op";
2237 for (
auto const &[untypedAttr, elementOrTensorType] :
2238 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2239 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2242 if (isa<ub::PoisonAttr>(untypedAttr)) {
2243 paddingValues.push_back(untypedAttr);
2247 emitOpError(
"expects padding values to be typed attributes or poison");
2251 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2255 if (!parsedAttr || parsedAttr.getType() != elementType) {
2256 auto diag = this->emitOpError(
"expects a padding that parses to ")
2257 << elementType <<
", got " << attr;
2258 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2261 paddingValues.push_back(parsedAttr);
2265 if (attr.getType() != elementType) {
2266 auto diag = this->emitOpError(
"expects a padding value of type ")
2267 << elementType <<
", got " << attr;
2268 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2271 paddingValues.push_back(attr);
2275 TilingInterface paddedOp;
2277 options.setPaddingValues(paddingValues)
2278 .setPaddingSizes(getMixedPaddingSizes())
2279 .setPadToMultipleOf(getPadToMultipleOf());
2284 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options,
2286 if (
failed(maybePaddedOp)) {
2287 auto diag = emitSilenceableError() <<
"failed to pad op";
2288 diag.attachNote(target->
getLoc()) <<
"target op";
2293 paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2294 padOps.append(newPadOps.begin(), newPadOps.end());
2297 results.
set(cast<OpResult>(getPadded()), paddedOps);
2298 results.
set(cast<OpResult>(getPad()), padOps);
2312 auto targetOps = state.getPayloadOps(getTarget());
2313 auto loopOps = state.getPayloadOps(getLoop());
2314 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2316 <<
"requires exactly one target and one loop handle (got "
2317 << llvm::range_size(targetOps) <<
" and "
2318 << llvm::range_size(loopOps) <<
")";
2321 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2322 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2323 if (!padOp || !loopOp)
2326 FailureOr<linalg::detail::PackingResult> result =
2332 if (result->clonedLoopIvs.empty()) {
2333 transformResults.
set(cast<OpResult>(getPackingLoop()),
2334 {result->hoistedPadOp.getOperation()});
2337 auto outerPackedLoop =
2339 transformResults.
set(cast<OpResult>(getPackingLoop()),
2340 {outerPackedLoop.getOperation()});
2346 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2347 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2349 return emitOpError() <<
"expects transpose to be a permutation, found "
2355 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2365 tensor::PadOp target,
2368 tensor::PadOp hoistedPadOp;
2370 FailureOr<Value> result =
2372 hoistedPadOp, transposeOps);
2373 if (succeeded(result)) {
2383 return emitDefaultSilenceableFailure(target);
2388 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2389 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2391 return emitOpError() <<
"expects transpose to be a permutation, found "
2407 if (!getOperandsToPromote().empty())
2409 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2410 if (getUseFullTilesByDefault())
2412 getUseFullTilesByDefault());
2413 if (getUseOriginalSubviewSize())
2417 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2418 if (!getUseFullTileBuffers().empty())
2420 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2421 if (getAlignment().has_value())
2422 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2423 if (getMemorySpace().has_value())
2424 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2426 if (getMapping().has_value()) {
2428 auto mapping = *getMapping();
2429 if (mapping.size() > 1)
2430 return emitDefaultDefiniteFailure(target);
2432 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2434 if (addressSpace.getAddressSpace() ==
2435 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2442 }
else if (addressSpace.getAddressSpace() ==
2443 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2451 return emitDefaultDefiniteFailure(target);
2456 return emitDefaultDefiniteFailure(target);
2459 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2461 return emitDefaultDefiniteFailure(target);
2474 auto payload = state.getPayloadOps(getTarget());
2478 if (target->getNumOperands() > 0)
2481 target->getNumRegions() > 0)
2483 <<
"expected target that is isolated from above";
2487 Operation *pattern = &getBodyRegion().front().front();
2490 if (getOperation()->isAncestor(target))
2495 replacements.push_back(replacement);
2497 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2501 void transform::ReplaceOp::getEffects(
2509 if (!getBodyRegion().hasOneBlock())
2510 return emitOpError() <<
"expected one block";
2511 if (std::distance(getBodyRegion().front().begin(),
2512 getBodyRegion().front().end()) != 1)
2513 return emitOpError() <<
"expected one operation in block";
2514 Operation *replacement = &getBodyRegion().front().front();
2517 <<
"expected replacement without operands";
2521 <<
"expect op that is isolated from above";
2539 target.createFlatListOfOperandDims(b, loc);
2540 AffineMap map = target.getShapesToLoopsMap();
2555 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2556 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2557 if (
failed(maybeTilingResult))
2558 return emitDefaultDefiniteFailure(target);
2560 if (target->getNumResults())
2561 rewriter.
replaceOp(target, maybeTilingResult->replacements);
2565 results.
reserve(maybeTilingResult->tiledOps.size());
2566 for (
Operation *tiled : maybeTilingResult->tiledOps)
2580 for (
Operation *target : state.getPayloadOps(getTarget())) {
2581 auto tilingOp = dyn_cast<TilingInterface>(*target);
2584 emitSilenceableError()
2585 <<
"expected the payload to implement TilingInterface";
2586 diag.attachNote(target->getLoc()) <<
"payload op";
2590 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2592 if (
failed(generatedLoops))
2593 return emitDefaultDefiniteFailure(target);
2594 for (scf::ForOp &loop : *generatedLoops) {
2595 loops.push_back(loop.getOperation());
2599 results.
set(cast<OpResult>(getResult()), loops);
2608 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2613 FailureOr<Operation *> maybeResult =
2615 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2616 [&rewriter](
auto op) {
2620 return emitDefaultSilenceableFailure(target);
2634 llvm::to_vector(state.getPayloadOps(getTarget()));
2636 bool isMultiwaySplit = getMultiway();
2638 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2640 <<
"requires exactly one target when "
2641 "multiway split is enabled (got "
2642 << llvm::range_size(payload) <<
")";
2647 if (!isMultiwaySplit)
2648 chunkSizes.reserve(payload.size());
2650 if (getDynamicChunkSizes()) {
2652 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2653 chunkSizes = llvm::to_vector(llvm::map_range(
2654 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2657 diag = emitSilenceableError()
2658 <<
"expected dynamic split point handle to point to a "
2659 "single-result index-typed op";
2660 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2665 chunkSizes = llvm::to_vector(
2666 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2669 if (
diag.isSilenceableFailure())
2674 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2676 <<
"expected the dynamic split point handle to point to as "
2678 << chunkSizes.size() <<
") as the target handle ("
2679 << payload.size() <<
")";
2682 chunkSizes.resize(payload.size(),
2686 auto checkStructuredOpAndDimensions =
2689 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2690 diag.attachNote(loc) <<
"target op";
2694 if (getDimension() >= linalgOp.getNumLoops()) {
2695 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2696 <<
" does not exist in target op";
2697 diag.attachNote(loc) <<
"target op";
2703 auto checkFailureInSplitting =
2707 diag.attachNote(loc) <<
"target op";
2714 if (isMultiwaySplit) {
2717 TilingInterface head, tail;
2720 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2724 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2725 if (
diag.isSilenceableFailure())
2731 target = tail.getOperation();
2736 linalgOp = cast<LinalgOp>(target);
2741 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2742 getDimension(), chunkSize);
2746 checkFailureInSplitting(!head && !tail, loc);
2747 if (
diag.isDefiniteFailure())
2750 opList.push_back(head.getOperation());
2755 opList.push_back(tail.getOperation());
2761 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2764 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2766 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2768 if (
diag.isSilenceableFailure())
2772 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2773 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2774 getDimension(), std::get<1>(pair));
2778 checkFailureInSplitting(!first.back() && !second.back(), loc);
2783 if (!second.back()) {
2784 noSecondPart = target;
2789 if (second.size() != first.size() && !second.empty()) {
2790 auto diag = emitSilenceableError()
2791 <<
"splitting does not produce the second part for a subset "
2794 <<
"expected splitting to produce the second part of all "
2795 "or none of the targets";
2797 <<
"first target with no second part";
2801 opList.append(first);
2803 opList.append(second);
2805 results.
set(cast<OpResult>(getSplitList()), opList);
2809 void SplitOp::getEffects(
2812 if (getDynamicChunkSizes())
2820 IntegerAttr staticChunkSizes;
2826 if (!dynamicPointParseResult.
has_value()) {
2827 int64_t staticChunkSizesValue;
2841 if (dynamicPointParseResult.
has_value()) {
2842 Type ChunkSizesType;
2855 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2862 printer <<
" " << getTarget() <<
" after ";
2863 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2864 if (staticChunkSize != ShapedType::kDynamic)
2865 printer << staticChunkSize;
2867 printer << getDynamicChunkSizes();
2870 {getStaticChunkSizesAttrName()});
2871 printer <<
" : " << getTarget().getType();
2872 if (staticChunkSize == ShapedType::kDynamic)
2873 printer <<
", " << getDynamicChunkSizes().getType();
2877 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2878 (getDynamicChunkSizes() ==
nullptr)) {
2879 return emitOpError() <<
"expects either a dynamic or a static split "
2880 "point to be provided";
2889 void transform::SplitReductionOp::build(
2891 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2892 bool useScalingAlgorithm,
bool useAlloc) {
2898 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2900 if (innerParallel) {
2901 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2904 if (useScalingAlgorithm) {
2906 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2914 result.
addTypes({resultType, resultType, resultType, resultType});
2923 unsigned(getInsertSplitDimension()),
2924 bool(getInnerParallel())};
2927 FailureOr<SplitReductionResult> splitResult =
2928 (getUseScalingAlgorithm())
2932 return emitDefaultDefiniteFailure(target);
2934 results.
push_back(splitResult->initOrAlloc);
2936 results.
push_back(splitResult->splitLinalgOp);
2937 results.
push_back(splitResult->resultCombiningLinalgOp);
2945 void transform::TileReductionUsingForOp::build(
2956 build(builder, result,
2960 staticTileSizesAttr);
2969 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2970 if (!partialReductionOp) {
2973 "Operation should implement PartialReductionOpInterface");
2977 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
2978 if (reductionDims.empty()) {
2979 for (
auto [idx, iteratorType] :
2981 if (iteratorType == utils::IteratorType::reduction)
2982 reductionDims.push_back(idx);
2988 options.setReductionTilingStrategy(
2991 options.setReductionDims(reductionDims);
2992 FailureOr<scf::SCFTilingResult> result =
2997 "failed to tile using partial reduction");
2999 rewriter.
replaceOp(target, result->replacements);
3000 for (
Value initValue : result->initialValues)
3002 for (
auto parallelTiledOp : result->tiledOps)
3004 for (
auto mergeOp : result->mergeOps)
3006 results.
push_back(result->loops.front());
3014 void transform::TileReductionUsingForallOp::build(
3017 ArrayAttr mapping) {
3027 build(builder, result,
3031 staticNumThreadsAttr,
3032 staticTileSizesAttr,
3042 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3043 if (!partialReductionOp) {
3046 "Operation should implement PartialReductionOpInterface");
3055 options.setReductionTilingStrategy(
3057 if (!getNumThreads().empty()) {
3058 options.setNumThreads(numThreads);
3060 options.setTileSizes(tileSizes);
3062 if (
auto mapping = getMapping()) {
3063 options.setMapping(mapping.value().getValue());
3066 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3067 if (reductionDims.empty()) {
3068 for (
auto [idx, iteratorType] :
3070 if (iteratorType == utils::IteratorType::reduction)
3071 reductionDims.push_back(idx);
3074 options.setReductionDims(reductionDims);
3075 FailureOr<scf::SCFTilingResult> result =
3079 auto diag = emitSilenceableError() <<
"could not tile reduction";
3082 rewriter.
replaceOp(target, result->replacements);
3084 for (
Value initValue : result->initialValues)
3086 for (
auto parallelTiledOp : result->tiledOps)
3088 for (
auto mergeOp : result->mergeOps)
3090 results.
push_back(result->loops.front());
3104 llvm::to_vector(state.getPayloadOps(getTarget()));
3106 if (!llvm::hasSingleElement(targetOps)) {
3108 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3113 auto linalgOp = dyn_cast<LinalgOp>(target);
3114 auto tileableOp = dyn_cast<TilingInterface>(target);
3119 OpBuilder builder(linalgOp.getContext());
3121 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3122 if (linalgOp.hasDynamicShape()) {
3123 auto diag = emitSilenceableError()
3124 <<
"cannot compute parametric tile sizes for dynamically "
3125 "shaped payload op";
3126 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3130 FailureOr<StaticContinuousTileSizeSpecification> spec =
3134 return emitSilenceableError()
3135 <<
"failed to compute multi-size tiling sizes";
3140 for (
auto &&[tileSize, tripCount] :
3141 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3142 chunkSizes.push_back(tileSize * tripCount);
3145 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
3150 getI64AttrsFromI64(spec->tileSizes));
3151 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3152 getI64AttrsFromI64(chunkSizes));
3160 unsigned dimension = getDimension();
3163 builder, tileableOp, dimension, targetSize,
true);
3165 return emitSilenceableError() <<
"could not generate tile size computation";
3177 for (
auto &&[tileSize, tripCount] :
3178 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3179 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3180 chunkSizes.push_back(splitPoint);
3184 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
3190 getDefiningOps(spec->tileSizes));
3191 transformResults.
set(cast<OpResult>(getChunkSizes()),
3192 getDefiningOps(chunkSizes));
3200 return emitOpError() <<
"expects all results type to be the same";
3206 void transform::ContinuousTileSizesOp::getEffects(
3224 Type &tileSizesType,
3225 Type &chunkSizesType) {
3226 FunctionType funcType;
3231 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3232 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3233 "argument and one result";
3235 targetType = funcType.getInput(0);
3236 tileSizesType = chunkSizesType = funcType.getResult(0);
3245 void transform::TileUsingForOp::build(
3250 return build(builder, result, loopTypes,
3254 interchange, scalableSizes);
3257 void transform::TileUsingForOp::build(
3261 build(builder, result, target,
3263 interchange, scalableSizes);
3266 void transform::TileUsingForOp::build(
3273 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3277 void transform::TileUsingForOp::build(
3289 unsigned numExpectedLoops =
3290 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3292 resultTypes.reserve(numExpectedLoops);
3293 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3294 "expected one loop type or as many as loops");
3295 if (loopTypes.size() == 1)
3296 resultTypes.append(numExpectedLoops, loopTypes[0]);
3298 llvm::append_range(resultTypes, loopTypes);
3300 if (scalableSizes.has_value())
3301 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3302 build(builder, result, target.
getType(),
3306 staticTileSizesAttr,
3308 expandedScalableSizes);
3313 return emitOpError(
"expected same number of sizes (")
3315 << getScalableSizes().size() <<
")";
3317 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3318 if (getLoops().size() != numExpectedLoops)
3319 return emitOpError(
"expected number of loops to tile (")
3320 << numExpectedLoops <<
") to match number of `loops` results ("
3321 << getLoops().size() <<
")";
3332 llvm::to_vector(state.getPayloadOps(getTarget()));
3338 if (isa<ParamType>(transformValue.getType())) {
3339 dynamicSizeProducers.push_back({});
3341 paramSizes.push_back(
3342 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
3343 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3346 if (paramSizes.back().size() != targets.size()) {
3348 emitSilenceableError()
3349 <<
"expected as many parameter values ("
3350 << dynamicSizeProducers.back().size() <<
") as target ops ("
3351 << targets.size() <<
")";
3352 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3358 paramSizes.push_back({});
3359 dynamicSizeProducers.push_back(
3360 llvm::to_vector(state.getPayloadOps(transformValue)));
3362 if (dynamicSizeProducers.back().size() != targets.size()) {
3364 emitSilenceableError()
3365 <<
"expected as many dynamic size-producing operations ("
3366 << dynamicSizeProducers.back().size() <<
") as target ops ("
3367 << targets.size() <<
")";
3368 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3372 for (
Operation *op : dynamicSizeProducers.back()) {
3379 emitSilenceableError() <<
"expected sizes to be produced by ops "
3380 "with a single index-type result";
3381 diag.attachNote(op->
getLoc()) <<
"size producer op";
3382 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3389 loops.resize(getLoops().size());
3390 auto scalableSizes = getScalableSizes();
3392 auto tilingInterface = dyn_cast<TilingInterface>(op);
3393 if (!tilingInterface) {
3395 emitSilenceableError()
3396 <<
"only ops implementing TilingInterface are supported";
3397 diag.attachNote(op->
getLoc()) <<
"target op";
3400 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3402 emitSilenceableError()
3403 <<
"too many tiles provided, expected at most "
3404 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3405 << tileSizes.size();
3406 diag.attachNote(op->
getLoc()) <<
"target op";
3411 if (tileSizes.empty()) {
3420 sizes.reserve(tileSizes.size());
3421 unsigned dynamicIdx = 0;
3424 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3425 if (scalableSizes[ofrIdx]) {
3427 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3429 vector::VectorScaleOp::create(b, getLoc(), b.
getIndexType());
3431 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3433 sizes.push_back(attr);
3440 assert((dynamicSizes.empty() ^ params.empty()) &&
3441 "expected either dynamic sizes or parameters");
3442 if (!params.empty()) {
3445 sizes.push_back(dynamicSizes[index]->getResult(0));
3453 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3454 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3455 if (
failed(maybeTilingResult))
3458 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3460 tiled.append(maybeTilingResult->tiledOps);
3462 loops[en2.index()].push_back(en2.value());
3465 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3467 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3476 results.reserve(tileSizes.size());
3477 unsigned dynamicPos = 0;
3479 for (int64_t size : tileSizes) {
3480 if (size == ShapedType::kDynamic) {
3481 results.push_back(dynamic[dynamicPos++]);
3489 void transform::TileUsingForOp::getEffects(
3501 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3505 ArrayAttr mapping) {
3506 return build(builder, result,
3514 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3518 ArrayAttr mapping) {
3528 build(builder, result,
3529 TypeRange{operationType, operationType},
3536 staticTileSizesAttr,
3540 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3544 ArrayAttr mapping) {
3545 return build(builder, result, target,
3550 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3554 ArrayAttr mapping) {
3565 build(builder, result,
3566 TypeRange{operationType, operationType},
3572 staticNumThreadsAttr,
3585 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3587 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3589 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3590 normalizedUbs.push_back(normalizedUb);
3592 return normalizedUbs;
3608 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3611 denormalizedIvs.push_back(
3614 return denormalizedIvs;
3625 scf::ForallOp loop) {
3642 auto normalizedForallOp = scf::ForallOp::create(
3643 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3644 loop.getOutputs(), loop.getMapping(),
3647 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3649 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3654 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3655 normalizedForallOp.getRegionIterArgs().end());
3656 Block *origLoopBlock = loop.getBody();
3657 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3659 rewriter.
replaceOp(loop, normalizedForallOp);
3660 return normalizedForallOp;
3665 TransformOpInterface transformOp,
Operation *target,
3670 auto tileableOp = dyn_cast<TilingInterface>(target);
3673 transformOp.emitSilenceableError()
3674 <<
"only TilingInterface ops are supported";
3675 diag.attachNote(target->
getLoc()) <<
"target op";
3681 if (!mixedNumThreads.empty()) {
3682 options.setNumThreads(mixedNumThreads);
3684 options.setTileSizes(mixedTileSizes);
3687 options.setMapping(mapping.value().getValue());
3689 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3692 if (
failed(maybeTilingResult))
3693 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3695 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3697 tilingResult = *maybeTilingResult;
3699 if (mixedNumThreads.empty()) {
3700 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3703 scf::ForallOp normalizedForallOp =
3705 tilingResult.
loops.front() = normalizedForallOp;
3715 auto transformOp = cast<TransformOpInterface>(getOperation());
3724 getPackedNumThreads()
3726 state, transformOp, mixedNumThreads, getPackedNumThreads())
3728 state, transformOp, mixedNumThreads, getMixedNumThreads());
3732 status = getPackedTileSizes()
3734 state, transformOp, mixedTileSizes, getPackedTileSizes())
3736 state, transformOp, mixedTileSizes, getMixedTileSizes());
3740 for (
Operation *target : state.getPayloadOps(getTarget())) {
3743 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3744 getMapping(), tilingResult);
3745 if (!
diag.succeeded())
3747 tileOps.push_back(tilingResult.
loops.front());
3748 tiledOps.append(tilingResult.
tiledOps);
3751 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3752 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3757 void transform::TileUsingForallOp::getEffects(
3770 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3779 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3780 static_cast<int>(getPackedNumThreads() !=
Value());
3781 if (numThreadsSpec > 1)
3783 "num_threads and packed_num_threads are mutually exclusive");
3784 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3785 static_cast<int>(getPackedTileSizes() !=
Value());
3786 if (tileSizesSpec > 1)
3788 "tile_sizes and packed_tile_sizes are mutually exclusive");
3789 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3790 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3791 "must be specified");
3799 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3801 bool foldTypeExtensionsIntoContract,
bool vectorizePadding,
3802 bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3804 if (foldTypeExtensionsIntoContract) {
3806 VectorizeChildrenAndApplyPatternsOp::
3807 getFoldTypeExtensionsIntoContractAttrName(result.
name),
3810 if (vectorizePadding) {
3812 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3816 if (vectorizeExtract) {
3818 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3822 if (flatten1DDepthwiseConv) {
3824 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3835 explicit VectorizationPattern(
MLIRContext *context,
3836 bool vectorizeExtract =
false,
3837 bool flattenConv =
false)
3839 vectorizeNDExtract(vectorizeExtract),
3840 flatten1DDepthwiseConv(flattenConv) {}
3841 LogicalResult matchAndRewrite(
Operation *op,
3845 "Unsupported Op, cannot vectorize");
3846 FailureOr<VectorizationResult> vectorResults =
3848 {}, vectorizeNDExtract,
3849 flatten1DDepthwiseConv);
3850 if (
failed(vectorResults))
3852 rewriter.
replaceOp(op, vectorResults->replacements);
3859 bool vectorizeNDExtract =
false;
3863 bool flatten1DDepthwiseConv =
false;
3868 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3873 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3874 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3880 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3881 getFlatten_1dDepthwiseConv());
3883 if (!getDisableTransferPermutationMapLoweringPatterns())
3886 if (!getDisableMultiReductionToContractPatterns())
3894 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
3895 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
3900 if (getFoldTypeExtensionsIntoContract())
3903 if (getVectorizePadding()) {
3915 return emitDefaultDefiniteFailure(target);
3929 auto targets = state.getPayloadOps(getTarget());
3930 if (std::empty(targets))
3932 auto transformOp = cast<TransformOpInterface>(getOperation());
3935 state, transformOp, getMixedVectorSizes(), vectorSizes);
3943 <<
"Unsupported Op, cannot vectorize";
3945 FailureOr<VectorizationResult> vectorResults =
3947 getVectorizeNdExtract().value_or(
false),
3949 getAssumeDynamicDimsMatchVecSizes().value_or(
false),
3950 getCreateNamedContraction().value_or(
false));
3951 if (
failed(vectorResults)) {
3953 <<
"Attempted to vectorize, but failed";
3955 rewriter.
replaceOp(target, vectorResults->replacements);
3961 void transform::VectorizeOp::getEffects(
3970 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3974 if (getStaticVectorSizes().size() != getScalableSizes().size())
3975 return emitOpError(
"expected same number of vector sizes (")
3976 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3977 << getScalableSizes().size() <<
")";
3986 transform::HoistRedundantVectorTransfersOp::applyToOne(
4003 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4022 auto maybeTransformed =
4025 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4028 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4031 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4034 .Case([&](linalg::Conv2DNchwFchwOp op) {
4040 if (
failed(maybeTransformed))
4041 return emitDefaultSilenceableFailure(target);
4043 results.
push_back(maybeTransformed->first);
4045 results.
push_back(maybeTransformed->second);
4060 <<
"only elementwise flattening is supported";
4063 if (target.getNumLoops() <= 1) {
4070 std::iota(reassociation.begin(), reassociation.end(), 0);
4071 auto maybeFlattened =
4073 if (
failed(maybeFlattened))
4075 <<
"attempted to flatten, but failed";
4076 results.
push_back(maybeFlattened->collapsedOp);
4077 rewriter.
replaceOp(target, maybeFlattened->results);
4090 auto maybeTransformed =
4092 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4095 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4101 if (
failed(maybeTransformed))
4102 return emitDefaultSilenceableFailure(target);
4117 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4118 auto maybeTransformed =
4120 .Case([&](linalg::MatmulOp op) {
4123 .Case([&](linalg::BatchMatmulOp op) {
4126 .Default([&](
Operation *op) {
return failure(); });
4127 if (
failed(maybeTransformed))
4137 template <
typename OpTy>
4141 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4142 tensor::ParallelInsertSliceOp>() &&
4145 if (
auto copySource =
4146 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4154 if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4157 Value extracted = tensor::ExtractSliceOp::create(
4158 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4159 target.getMixedSizes(), target.getMixedStrides());
4160 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4161 target.getSource(), extracted)
4166 target, copied, target.getDest(), target.getMixedOffsets(),
4167 target.getMixedSizes(), target.getMixedStrides());
4179 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4180 return doit(rewriter, target, results, state);
4181 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4182 return doit(rewriter, target, results, state);
4185 emitSilenceableError()
4186 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4187 diag.attachNote(targetOp->
getLoc()) <<
"target op";
4200 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4202 emitSilenceableError()
4203 <<
"only linalg.copy and tensor.pad target ops are supported";
4204 diag.attachNote(target->
getLoc()) <<
"target op";
4207 assert(target->
getNumResults() == 1 &&
"expected single result");
4209 if (!resultShapedType.hasStaticShape()) {
4211 emitSilenceableError()
4212 <<
"only statically sized ops of rank <= 3 are supported";
4213 diag.attachNote(target->
getLoc()) <<
"target op";
4218 int64_t desiredBitAlignment = getDesiredBitAlignment();
4219 int64_t eltBitwidth =
4220 resultShapedType.getElementType().getIntOrFloatBitWidth();
4221 if (desiredBitAlignment % eltBitwidth != 0) {
4222 desiredBitAlignment = eltBitwidth;
4227 getTotalNumThreads(),
4228 desiredBitAlignment,
4229 resultShapedType.getShape(),
4232 resultShapedType.getElementType().getIntOrFloatBitWidth());
4233 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4235 emitSilenceableError()
4236 <<
"too few threads to map copy op to threads on the most minor "
4237 "dimension, given alignment and vector size constraints, try "
4238 "smaller tile size of mapping to more threads";
4239 diag.attachNote(target->
getLoc()) <<
"target op";
4255 if (!
diag.succeeded())
4259 for (
auto op : tilingResult.
tiledOps)
4273 FailureOr<Operation *> maybeTransformed = failure();
4275 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4280 .Default([&](
Operation *op) {
return false; });
4283 return emitSilenceableError()
4284 <<
"this operation is not supported to convert to Winograd Conv2D";
4287 if (
failed(maybeTransformed)) {
4288 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4300 FailureOr<Operation *> maybeTransformed = failure();
4303 .Case([&](linalg::WinogradFilterTransformOp op) {
4307 .Case([&](linalg::WinogradInputTransformOp op) {
4311 .Case([&](linalg::WinogradOutputTransformOp op) {
4315 .Default([&](
Operation *op) {
return false; });
4319 emitSilenceableError()
4320 <<
"this operation is not supported to decompose into other operations";
4321 diag.attachNote(target->
getLoc()) <<
"target op";
4325 if (
failed(maybeTransformed)) {
4327 emitSilenceableError() <<
"decompose Winograd operations failed";
4328 diag.attachNote(target->
getLoc()) <<
"target op";
4336 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4338 #define GET_OP_CLASSES
4339 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs))
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs))
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.