44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include <type_traits>
/// Debug-logging helpers for this file. Enabled only in builds where
/// LLVM_DEBUG is active (e.g. with `-debug-only=linalg-transforms`).
/// DBGS()  — stream prefixed with the debug tag, for composing messages.
/// DBGSNL() — bare newline on the debug stream.
/// LDBG(X) — one-shot: prints X (any streamable value) with prefix + newline.
#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
65 template <
typename PatternTy,
typename... Args>
68 using OpTy =
typename llvm::function_traits<
69 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
70 auto op = dyn_cast<OpTy>(operation);
75 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
80 explicit TrivialPatternRewriter(
MLIRContext *context)
83 TrivialPatternRewriter rewriter(operation->
getContext());
84 rewriter.setInsertionPoint(operation);
85 auto result = pattern.returningMatchAndRewrite(op, rewriter);
88 return cast<LinalgOp>(result->getOperation());
98 if (
auto attr = dyn_cast<Attribute>(ofr)) {
99 if (!isa<IntegerAttr>(attr))
100 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
101 result.push_back(ofr);
105 Value transformValue = cast<Value>(ofr);
106 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
108 if (params.size() != 1)
109 return transformOp.emitDefiniteFailure()
110 <<
"requires exactly one parameter associated";
111 result.push_back(params[0]);
115 auto payloadOps = state.getPayloadOps(transformValue);
116 if (!llvm::hasSingleElement(payloadOps)) {
118 transformOp.emitSilenceableError()
119 <<
"handle must be mapped to exactly one payload op";
121 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
128 transformOp.emitSilenceableError()
129 <<
"payload op must have exactly 1 index result";
149 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
151 for (
auto param : params) {
152 if (!isa<IntegerAttr>(param))
153 return transformOp.emitDefiniteFailure()
154 <<
"expected the parameter to be associated with an integer "
156 result.push_back(param);
161 for (
Operation *op : state.getPayloadOps(packedHandle)) {
162 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
164 transformOp.emitSilenceableError()
165 <<
"payload op must have exactly 1 index result";
166 diag.attachNote(op->getLoc())
167 <<
"has " << op->getNumResults() <<
" results";
170 result.push_back(op->getResult(0));
184 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
185 reified.push_back(cast<IntegerAttr>(attr).getInt());
187 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
189 if (params.size() != 1)
190 return transformOp.emitSilenceableError() <<
"expected a single param";
192 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
196 Value handle = cast<Value>(paramOrHandle);
197 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
198 return transformOp.emitSilenceableError() <<
"unexpected value handle";
199 auto payload = state.getPayloadOps(handle);
200 if (!llvm::hasSingleElement(payload))
201 return transformOp.emitSilenceableError()
202 <<
"requires param or handle that is mapped to 1 payload op";
204 Operation *paramOrHandlePayloadOp = *payload.begin();
207 return transformOp.emitSilenceableError()
208 <<
"requires param or handle to be result of op with 1 index "
214 return transformOp.emitSilenceableError()
215 <<
"requires param or handle to be the result of a constant like "
218 reified.push_back(attr.getInt());
227 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
232 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
237 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
242 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
248 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
251 options.rankReductionStrategy =
256 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
261 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
266 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
271 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
276 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
285 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
290 resultTypes.push_back(b.
getType<transform::AnyValueType>());
291 resultTypes.push_back(b.
getType<transform::AnyOpType>());
292 return build(b, result,
298 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
301 int64_t memorySpace) {
303 resultTypes.push_back(b.
getType<transform::AnyValueType>());
304 resultTypes.push_back(b.
getType<transform::AnyOpType>());
305 return build(b, result,
321 void notifyOperationInserted(
Operation *op,
323 ForwardingListener::notifyOperationInserted(op, previous);
325 if (previous.
isSet())
327 auto inserted = newOps.insert(op);
329 assert(inserted.second &&
"expected newly created op");
332 void notifyOperationErased(
Operation *op)
override {
333 ForwardingListener::notifyOperationErased(op);
347 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
348 NewOpsListener newOpsListener(previousListener);
352 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
355 }
else if (getMemcpyOp() ==
"memref.copy") {
358 }
else if (getMemcpyOp() ==
"linalg.copy") {
362 llvm_unreachable(
"invalid memcpy op");
364 if (getAllocOp() ==
"memref.alloc") {
367 }
else if (getAllocOp() ==
"memref.alloca") {
371 llvm_unreachable(
"invalid alloc op");
373 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
374 options.emitDealloc = getEmitDealloc();
378 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
380 for (
Operation *op : state.getPayloadOps(getTarget())) {
385 <<
"failed to bufferize operation";
386 diag.attachNote(op->
getLoc()) <<
"target payload op";
389 allocatedBuffers.push_back(buffer);
393 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
394 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
398 void transform::BufferizeToAllocationOp::getEffects(
400 if (getBufferizeDestinationOnly()) {
412 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
413 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
414 return emitOpError() <<
"unsupported memcpy op";
415 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
416 return emitOpError() <<
"unsupported alloc op";
429 #define DOWNSCALE(trans) \
431 FailureOr<LinalgOp> res = tryApply<trans>(target); \
432 if (succeeded(res)) { \
433 results.push_back(*res); \
434 return DiagnosedSilenceableFailure::success(); \
438 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
439 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
452 #undef DOWNSCALE_NORMAL
453 #undef DOWNSCALE_CALL
455 return emitDefaultSilenceableFailure(target);
469 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
470 if (!decomposableOp) {
472 "payload is not a decomposable op"));
473 return emitDefaultSilenceableFailure(target);
476 FailureOr<SmallVector<Value>> maybeNewResults =
477 decomposableOp.decomposeOperation(rewriter);
478 if (failed(maybeNewResults))
479 return emitDefaultSilenceableFailure(target);
481 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
482 for (
Value val : *maybeNewResults) {
483 Operation *definition = val.getDefiningOp();
494 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
501 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
505 options.allowReturnAllocsFromLoops =
true;
507 for (
Operation *target : state.getPayloadOps(getTarget())) {
511 <<
"failed to analyze op";
513 rewriter, target, state)))
515 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
526 template <
typename Range>
530 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
536 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
537 if (!tilingInterfaceOp)
538 return transformOp->
emitError(
"only TilingInterface ops are supported");
541 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
542 applyFn(tilingInterfaceOp);
543 if (failed(tiledResults))
548 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
549 for (
Operation *toReplace : opsToReplace) {
550 for (
OpResult res : toReplace->getResults())
551 if (
auto replacement = tiledResults->replacements.lookup(res))
553 if (toReplace->use_empty()) {
559 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
560 assert(tiledResults->loops.size() == numLoops &&
561 "Mismatched number of loops, tile and fuse transform should have "
563 for (
unsigned int i = 0; i < numLoops; ++i)
564 loopOps[i].push_back(tiledResults->loops[i]);
567 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
568 for (
unsigned int i = 0; i < numLoops; ++i)
569 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
581 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
587 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
591 if (getApplyCleanup()) {
594 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
601 rewriter, getOperation(), state.getPayloadOps(getTarget()),
602 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
603 [&](TilingInterface tilingInterfaceOp)
604 -> FailureOr<scf::SCFTileAndFuseResult> {
605 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
614 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
615 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
616 if (!std::is_permutation(sequence.begin(), sequence.end(),
617 permutation.begin(), permutation.end())) {
618 return emitOpError() <<
"expects interchange to be a permutation, found "
619 << getTileInterchange();
624 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
625 if (numExpectedLoops != getNumResults() - 1)
626 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
635 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
638 Value containingOp) {
641 result.
addTypes({resultType, resultType});
657 (domInfo.
dominates(containingOp, user))) {
658 dominatedUsers.insert(user);
661 if (dominatedUsers.empty())
665 auto forallOp = cast<scf::ForallOp>(containingOp);
671 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
676 newOuts.push_back(outputs[resultNumber]);
679 auto newforallOp = rewriter.
create<scf::ForallOp>(
680 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
681 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
683 newforallOp.getRegion().takeBody(forallOp.getRegion());
688 newforallOp.getBody()->addArgument(newOuts.back().getType(),
689 newOuts.back().getLoc());
690 auto bbArgs = newforallOp.getBody()->getArguments();
698 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
700 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
701 Operation *firstYieldOp = yieldingOps.front();
704 Value dst = newforallOp.getRegionIterArgs().back();
706 rewriter.
create<tensor::ParallelInsertSliceOp>(firstYieldOp->
getLoc(), src,
707 dst, offsets, sizes, strides);
711 newforallOp->getResult(result.index()));
714 newforallOp->getResults().back(),
717 return dominatedUsers.contains(user);
731 destWorklist.push_back(dst);
733 while (!destWorklist.empty()) {
734 Value currentDst = destWorklist.pop_back_val();
738 if (src == currentDst)
743 auto bbArg = dyn_cast<BlockArgument>(currentDst);
747 Block *parentBlock = bbArg.getOwner();
748 assert(parentBlock &&
"unlinked block argument");
751 assert(parentOp &&
"expected block argument with parent operation");
754 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
758 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
760 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
761 Value loopBlockArgument =
763 destWorklist.push_back(loopBlockArgument);
776 static std::tuple<SmallVector<Operation *>,
Operation *>
779 LLVM_DEBUG(
DBGS() <<
"Try to fuse a direct extract use\n");
780 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
781 if (!tileableProducer) {
783 <<
"producer is not a TileableInterface: " << *producerOp;
790 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
791 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
792 return sliceOp && containingOp->isProperAncestor(sliceOp);
796 if (it == tileableProducer->getUsers().end()) {
797 diag.attachNote(tileableProducer->getLoc())
798 <<
"could not find fusion opportunity for: " << *tileableProducer;
801 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
814 if (LoopLikeOpInterface containerLoop =
815 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
822 cast<DestinationStyleOpInterface>(
clone).getDpsInitsMutable()) {
823 Value producerOperand =
826 containerLoop.getRegionIterArgs()) {
827 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
828 Value consumerOperand =
829 containerLoop->getOperand(bbArg->getOperandNumber());
831 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
832 initOperandPtr.set(containerIterArg);
838 tileableProducer = dyn_cast<TilingInterface>(
clone);
842 int64_t resultNumber =
843 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
844 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
849 FailureOr<TilingResult> tileAndFuseResult =
850 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
853 if (failed(tileAndFuseResult)) {
854 diag.attachNote(tileableProducer->getLoc())
855 <<
"failed to tile producer op: " << *tileableProducer;
860 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
861 LLVM_DEBUG(
DBGS() <<
"tiledProducer: " << *tiledOp <<
"\n");
866 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
867 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
868 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
869 if (failed(maybeRankReduced)) {
871 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
872 << tileAndFuseResult->tiledValues[0]
873 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
876 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
880 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
881 resultNumber, offsets, sizes);
884 if (dyn_cast<LoopLikeOpInterface>(containingOp))
885 rewriter.
eraseOp(tileableProducer);
887 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
900 LLVM_DEBUG(
DBGS() <<
"Try to fuse an extract use through block argument\n");
902 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
903 if (!tileableProducer) {
905 <<
"producer is not a TileableInterface: " << *producerOp;
910 scf::ForallOp forallOp;
911 auto itProducerUses =
912 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
913 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
917 if (!forallOp || forallOp != containingOp) {
918 diag.attachNote(tileableProducer->getLoc())
919 <<
"could not find a use by the containing op: " << *tileableProducer;
934 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
935 return sliceOp && containingOp->isProperAncestor(sliceOp);
939 if (itBBArgUsers == bbArg.
getUsers().end()) {
941 <<
"could not find fusion opportunity for bbArg: " << bbArg;
944 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
952 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
953 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
958 rewriter, tileableProducer->getLoc(), tileableProducer,
959 destinationTensors))) {
960 diag.attachNote(tileableProducer->getLoc())
961 <<
"failed to get destination tensors for: " << *tileableProducer;
966 bvm.
map(destinationTensors[resultNumber], bbArg);
967 auto tileableProducerClone =
968 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
970 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
973 FailureOr<TilingResult> tileAndFuseResult =
974 tileableProducerClone.generateResultTileValue(
975 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
976 sliceOpToTile.getMixedSizes());
977 if (failed(tileAndFuseResult)) {
978 diag.attachNote(tileableProducer->getLoc())
979 <<
"failed to tile producer op: " << *tileableProducer;
984 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
985 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
986 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
987 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
988 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
993 destinationTensors.front());
996 return tileAndFuseResult->tiledOps;
1002 LLVM_DEBUG(
DBGS() <<
"Try to fuse an use by cloning\n");
1007 for (
OpOperand &use : result.getUses()) {
1009 uses.push_back(&use);
1014 if (containingOp == use.getOwner()) {
1016 <<
"producer op use by containing op cannot be fused by cloning";
1024 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1033 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1034 "Parallel insert slice is not a valid clone destination");
1035 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1036 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
1040 fusedOp = rewriter.
clone(*producerOp);
1042 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1047 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1057 auto producerOps = state.getPayloadOps(getProducerOp());
1058 auto containingOps = state.getPayloadOps(getContainingOp());
1059 if (!llvm::hasSingleElement(containingOps)) {
1061 <<
"requires exactly one containing_op handle (got "
1062 << llvm::range_size(containingOps) <<
")";
1064 Operation *containingOp = *containingOps.begin();
1067 if (std::empty(producerOps)) {
1069 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1076 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1077 for (
const auto &it :
enumerate(remainingProducers)) {
1080 int64_t numUsesInContainingOp =
1082 return containingOp->isAncestor(op);
1087 if (numUsesInContainingOp > 0) {
1088 if (numUsesInContainingOp == 1)
1089 remainingProducers.erase(remainingProducers.begin() + it.index());
1096 while (!remainingProducers.empty()) {
1097 auto nextProducer = getNextProducer();
1098 if (failed(nextProducer)) {
1100 <<
"could not find next producer to fuse into container";
1101 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1109 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1116 auto [tiledOps, newContainingOp] =
1118 if (!tiledOps.empty()) {
1119 LLVM_DEBUG(
DBGS() <<
"\nFused a direct extract use\n" << *containingOp);
1120 fusedOps.append(tiledOps);
1121 if (newContainingOp) {
1129 LogicalResult replacementStatus =
1132 (void)replacementStatus;
1133 assert(succeeded(replacementStatus) &&
1134 "unable to update transform state mapping");
1135 rewriter.
eraseOp(containingOp);
1136 containingOp = newContainingOp;
1143 rewriter,
diag, producerOp, containingOp);
1144 if (!tiledContainingOpOperand.empty()) {
1145 LLVM_DEBUG(
DBGS() <<
"\nFused an extract use through block argument\n"
1147 fusedOps.append(tiledContainingOpOperand);
1154 LLVM_DEBUG(
DBGS() <<
"\nFused an use by cloning\n" << *containingOp);
1155 fusedOps.push_back(cloned);
1161 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1162 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1166 void transform::FuseIntoContainingOp::getEffects(
1184 if (isa<GenericOp>(target)) {
1190 if (succeeded(
generic)) {
1191 results.
push_back(generic->getOperation());
1194 return emitDefaultSilenceableFailure(target);
1207 if (!isa<GenericOp>(target)) {
1212 FailureOr<LinalgOp> named =
1214 if (succeeded(named)) {
1215 results.
push_back(named->getOperation());
1218 return emitDefaultSilenceableFailure(target);
1232 if (interchangeVector.empty()) {
1237 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1238 if (interchangeVector.size() != numLoops) {
1239 return emitSilenceableError()
1240 << getIteratorInterchangeAttrName() <<
" has length ("
1241 << interchangeVector.size()
1242 <<
") different from the number of loops in the target operation ("
1255 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1256 if (!std::is_permutation(sequence.begin(), sequence.end(),
1257 permutation.begin(), permutation.end())) {
1258 return emitOpError()
1259 <<
"expects iterator_interchange to be a permutation, found "
1260 << getIteratorInterchange();
1275 if (!isa<linalg::CopyOp>(targetOp)) {
1277 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1278 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1282 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1283 if (!copyOp.hasPureBufferSemantics()) {
1285 emitSilenceableError()
1286 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1287 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1293 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1294 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1295 Value input = inputs.front();
1296 Value output = outputs.front();
1301 if (!isa<ShapedType>(input.
getType())) {
1303 emitSilenceableError()
1304 <<
"cannot transform a linalg.copy which input has no shape";
1305 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1310 assert(isa<ShapedType>(output.
getType()));
1312 if (cast<ShapedType>(input.
getType()).getElementType() !=
1313 cast<ShapedType>(output.
getType()).getElementType()) {
1315 emitSilenceableError()
1316 <<
"cannot transform a linalg.copy with different source and "
1317 "destination element types ";
1318 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1339 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1340 FailureOr<LowerPackResult> res =
1341 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1344 <<
"cannot lower to pad + expand + transpose";
1347 transformResults.
push_back(res->expandShapeOp);
1348 transformResults.
push_back(res->transposeOp);
1361 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1362 FailureOr<LowerUnPackOpResult> res =
1363 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1366 emitSilenceableError()
1367 <<
"cannot lower to transpose + collapse + extract";
1368 diag.attachNote(target->getLoc()) <<
"target payload op";
1371 transformResults.
push_back(res->emptyOp);
1372 transformResults.
push_back(res->transposeOp);
1373 transformResults.
push_back(res->collapseShapeOp);
1374 transformResults.
push_back(res->extractSliceOp);
1404 if (getOps().has_value())
1405 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1407 auto payloadOps = state.getPayloadOps(getTarget());
1408 if (!llvm::hasSingleElement(payloadOps)) {
1413 bool incorrectNumOperandTypes =
false;
1420 if (getInterface().has_value()) {
1421 auto iface = getInterface().value();
1422 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1425 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1426 !isa<TilingInterface>(op))
1428 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1429 !isa<LoopLikeOpInterface>(op))
1434 if (getOpAttrs().has_value()) {
1435 DictionaryAttr opAttrs = getOpAttrs().value();
1437 if (attr.getName() == getInterfaceAttrName() ||
1438 attr.getName() == getOpsAttrName())
1440 if (!op->
hasAttr(attr.getName()))
1442 if (op->
getAttr(attr.getName()) != attr.getValue())
1447 if (getFilterResultType().has_value()) {
1448 Type t = getFilterResultType().value();
1453 if (getFilterOperandTypes().has_value()) {
1454 mlir::ArrayAttr types = getFilterOperandTypes().value();
1457 if (types.size() == 1) {
1460 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1461 Type t = cast<::mlir::Type>(typeattr.getValue());
1463 [&](
Type operandType) { return operandType == t; }))
1468 if (types.size() != operandTypes.size()) {
1469 incorrectNumOperandTypes =
true;
1473 for (
auto [attr, operandType] :
1474 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1475 auto typeattr = cast<mlir::TypeAttr>(attr);
1476 Type type = cast<::mlir::Type>(typeattr.getValue());
1478 if (type != operandType)
1489 (*payloadOps.begin())->
walk(matchFun);
1490 if (incorrectNumOperandTypes)
1492 "type, then it must contain as much types as "
1493 "the number of operands in the target ops");
1494 results.
set(cast<OpResult>(getResult()), res);
1509 Type &targetType,
Type &lowSizeType,
1511 Type &splitPointType) {
1512 FunctionType funcType;
1514 if (failed(parser.
parseType<FunctionType>(funcType)))
1517 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1518 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1519 "argument and one result";
1521 targetType = funcType.getInput(0);
1522 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1530 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1531 if (target.hasDynamicShape()) {
1532 auto diag = emitSilenceableError()
1533 <<
"cannot compute parametric tile sizes for dynamically "
1534 "shaped payload op";
1535 diag.attachNote(target->getLoc()) <<
"payload op";
1540 target, getDimension(), getTargetSize(), getDivisor());
1542 return emitSilenceableError()
1543 <<
"failed to compute multi-size tiling sizes";
1546 Builder builder(target.getContext());
1547 results.
assign(llvm::map_range(
1549 spec->lowTileSize * spec->lowTripCount}),
1550 [&builder,
this](int64_t value) {
1562 builder, target, getDimension(), targetSize, divisor);
1564 return emitSilenceableError() <<
"could not generate tile size computation";
1571 {spec->lowTileSize, spec->lowTripCount});
1572 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1573 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1574 assert(lowTileSize && highTileSize && splitPoint &&
1575 "tile sizes are not produced by operations");
1583 void transform::MultiTileSizesOp::getEffects(
1587 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1596 return emitOpError() <<
"expects all results type to be the same";
1616 builder.
getContext(), GenericOp::getOperationName());
1617 build(builder, result,
1626 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1633 auto targetOps = state.getPayloadOps(getTarget());
1635 if (std::empty(targetOps)) {
1636 transformResults.
set(cast<OpResult>(getPackedOp()),
1641 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1642 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1643 return emitSilenceableError()
1644 <<
"requires target to map to exactly 1 LinalgOp (got "
1645 << llvm::range_size(targetOps) <<
")";
1648 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1649 return emitSilenceableError()
1650 <<
"requires number of packed sizes match the number of loops ("
1651 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1658 state, *
this, packedSizes, getMixedPackedSizes());
1661 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1662 if (failed(maybeResult))
1665 transformResults.
set(cast<OpResult>(getPackedOp()),
1666 {maybeResult->packedLinalgOp.getOperation()});
1670 void transform::PackOp::getEffects(
1684 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1685 <<
" is not a valid permutation";
1688 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1689 for (
auto [s, nmo] :
1690 llvm::zip_equal(getMixedMatmulPackedSizes(),
1691 getMatmulPaddedSizesNextMultipleOf())) {
1694 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1695 return emitOpError() <<
"at most one of the packed_size and the "
1696 "padded_sizes_next_multiple_of can be nonzero "
1697 "for the matmul strategy";
1709 for (
Operation *op : state.getPayloadOps(getTarget())) {
1710 auto linalgOp = dyn_cast<LinalgOp>(op);
1721 getMixedMatmulPackedSizes(),
1723 getMatmulPaddedSizesNextMultipleOf(),
1724 getMatmulInnerDimsOrder());
1725 if (succeeded(packResult)) {
1726 results.push_back(packResult->packedLinalgOp);
1729 results.push_back(linalgOp);
1731 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1737 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1741 void transform::PackGreedilyOp::getEffects(
1755 return emitOpError() << getInnerPermAttrName()
1756 <<
" is not a valid permutation";
1759 return emitOpError() << getOuterPermAttrName()
1760 <<
" is not a valid permutation";
1762 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1763 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1764 <<
" or " << getOuterPermAttrName()
1765 <<
" must be specified";
1771 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1781 template <
typename RelayoutOpTy>
1784 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1786 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1787 "applies to only pack or unpack operations");
1788 if (!op || permutation.empty())
1790 size_t innerRank = op.getInnerDimsPos().size();
1791 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1795 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1796 return permutation.size() == op.getSourceRank() &&
1799 return permutation.size() == op.getDestRank() &&
1807 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1808 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1810 if (std::empty(packOrUnpackOps)) {
1811 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1812 transformResults.
set(cast<OpResult>(getPackOp()), {});
1813 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1819 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1820 !llvm::hasSingleElement(linalgOps)) {
1821 return emitSilenceableError()
1822 <<
"requires target to map to exactly 1 "
1823 "packing op and 1 packed op ("
1824 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1825 << llvm::range_size(linalgOps) <<
")";
1829 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1830 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1831 if ((!packOp && !unPackOp)) {
1832 return emitSilenceableError() <<
"requires target to map to a "
1833 "linalg.pack or linalg.unpack";
1835 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1836 if (!linalgOpTarget)
1837 return emitSilenceableError() <<
"requires a LinalgOp target";
1841 if (packOp && packOp.getResult().hasOneUse())
1842 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1844 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1845 if (linalgOp != linalgOpTarget) {
1847 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1848 : StringLiteral{
"not produced by the LinalgOp target"};
1849 return emitSilenceableError() << errorMsg;
1855 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1856 OpOperand *packUse = linalgOp.getDpsInitOperand(
1857 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1859 if (!packOp || !packOp.getResult().hasOneUse())
1860 return emitSilenceableError() <<
"could not find matching pack op";
1864 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1866 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1867 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1868 ? StringLiteral{
"invalid outer_perm"}
1869 : StringLiteral{
"invalid inner_perm"};
1873 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1874 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1880 assert(packOp && linalgOp &&
"unexpected null op");
1884 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1886 assert(succeeded(res) &&
"unexpected packTranspose failure");
1889 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1890 transformResults.
set(cast<OpResult>(getPackedOp()),
1891 {res->transposedLinalgOp});
1893 transformResults.
set(cast<OpResult>(getUnPackOp()),
1894 {res->transposedUnPackOp});
1896 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1911 StringRef copyBackOp,
1912 bool usePrescribedTensorShapes) {
1922 (padToMultipleOf.empty()
1924 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1925 b.getI64ArrayAttr(nofoldFlags),
1926 b.getArrayAttr(transposePaddings),
1927 b.getStringAttr(copyBackOp),
1929 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
1937 StringRef copyBackOp,
1938 bool usePrescribedTensorShapes) {
1943 staticPadToMultipleOf);
1950 dynamicPadToMultipleOf,
1951 staticPadToMultipleOf,
1955 usePrescribedTensorShapes);
1958 void PadOp::getEffects(
1968 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1975 auto transformOp = cast<TransformOpInterface>(getOperation());
1978 for (
Operation *target : state.getPayloadOps(getTarget())) {
1979 auto linalgTarget = dyn_cast<LinalgOp>(target);
1980 if (!linalgTarget) {
1981 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1982 diag.attachNote(target->
getLoc()) <<
"target op";
1988 for (int64_t packPadding :
1989 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1990 nofoldFlags.push_back(
static_cast<bool>(packPadding));
1994 for (
auto const &it :
1995 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1996 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1998 emitOpError(
"expects padding values to be typed attributes");
2003 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2007 if (!parsedAttr || parsedAttr.getType() != elementType) {
2008 auto diag = this->emitOpError(
"expects a padding that parses to ")
2009 << elementType <<
", got " << std::get<0>(it);
2010 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2013 paddingValues.push_back(parsedAttr);
2017 if (attr.getType() != elementType) {
2018 auto diag = this->emitOpError(
"expects a padding value of type ")
2019 << elementType <<
", got " << attr;
2020 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2023 paddingValues.push_back(attr);
2028 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2029 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2030 cast<ArrayAttr>(transposeVector)));
2035 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2039 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2042 if (padToMultipleOf.empty())
2046 options.padToMultipleOf = padToMultipleOf;
2047 options.paddingValues = paddingValues;
2048 options.nofoldFlags = nofoldFlags;
2049 if (getCopyBackOp() ==
2050 bufferization::MaterializeInDestinationOp::getOperationName()) {
2053 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2055 }
else if (getCopyBackOp() == kCopyOpNone) {
2058 llvm_unreachable(
"unsupported copy_back op");
2061 bool irChanged =
false;
2062 if (getUsePrescribedTensorShapes() &&
2063 linalgTarget.hasPureTensorSemantics()) {
2066 for (
OpOperand &operand : linalgTarget->getOpOperands()) {
2067 for (
auto [i, dim] :
llvm::enumerate(linalgTarget.getShape(&operand))) {
2068 if (!ShapedType::isDynamic(dim))
2070 options.setSizeToPadTo(operand.getOperandNumber(), i,
2072 operand.get().getLoc(),
2082 replacements, newPadOps))) {
2085 diag.attachNote(target->
getLoc()) <<
"target op";
2088 auto diag = emitSilenceableError() <<
"failed to pad op";
2089 diag.attachNote(target->
getLoc()) <<
"target op";
2098 rewriter.
replaceOp(linalgTarget, replacements);
2099 paddedOps.push_back(paddedOp);
2100 padOps.append(newPadOps.begin(), newPadOps.end());
2102 for (
Value v : replacements) {
2103 Operation *copyBackOp = v.getDefiningOp();
2104 if (!llvm::is_contained(copyBackOps, copyBackOp))
2105 copyBackOps.push_back(copyBackOp);
2110 results.
set(cast<OpResult>(getPadded()), paddedOps);
2111 results.
set(cast<OpResult>(getPad()), padOps);
2112 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2118 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2119 if (any_of(nofoldFlags, [](int64_t packPadding) {
2120 return packPadding != 0 && packPadding != 1;
2122 return emitOpError()
2123 <<
"expects nofold_flags to contain booleans (0/1), found "
2124 << getNofoldFlags();
2128 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2129 if (any_of(paddingDimensions,
2130 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2131 return emitOpError() <<
"expects padding_dimensions to contain positive "
2133 << getPaddingDimensions();
2135 if (!getMixedPadToMultipleOf().empty()) {
2136 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2137 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2140 ArrayAttr transposes = getTransposePaddings();
2143 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2144 if (!std::is_permutation(sequence.begin(), sequence.end(),
2146 return emitOpError()
2147 <<
"expects transpose_paddings to be a permutation, found "
2151 if (getCopyBackOp() !=
2152 bufferization::MaterializeInDestinationOp::getOperationName() &&
2153 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2154 getCopyBackOp() != kCopyOpNone)
2155 return emitOpError() <<
"invalid copy_back_op";
2163 void transform::PadTilingInterfaceOp::build(
OpBuilder &b,
2168 bool padToMultipleOf) {
2179 : b.getDenseI64ArrayAttr(paddingSizes)),
2181 padToMultipleOf ? b.getUnitAttr() : nullptr);
2184 void transform::PadTilingInterfaceOp::build(
2192 staticPaddingSizes);
2199 dynamicPaddingSizes,
2204 void transform::PadTilingInterfaceOp::getEffects(
2213 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2215 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2224 for (
Operation *target : state.getPayloadOps(getTarget())) {
2225 auto targetOp = dyn_cast<TilingInterface>(target);
2227 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2228 diag.attachNote(target->
getLoc()) <<
"target op";
2235 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2236 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2238 diag.attachNote(target->
getLoc()) <<
"target op";
2244 for (
auto const &[untypedAttr, elementOrTensorType] :
2245 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2246 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2249 emitOpError(
"expects padding values to be typed attributes");
2253 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2257 if (!parsedAttr || parsedAttr.getType() != elementType) {
2258 auto diag = this->emitOpError(
"expects a padding that parses to ")
2259 << elementType <<
", got " << attr;
2260 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2263 paddingValues.push_back(parsedAttr);
2267 if (attr.getType() != elementType) {
2268 auto diag = this->emitOpError(
"expects a padding value of type ")
2269 << elementType <<
", got " << attr;
2270 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2273 paddingValues.push_back(attr);
2277 TilingInterface paddedOp;
2279 options.setPaddingValues(paddingValues)
2280 .setPaddingDimensions(
2281 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
2282 .setPaddingSizes(getMixedPaddingSizes())
2283 .setPadToMultipleOf(getPadToMultipleOf());
2288 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options,
2290 if (failed(maybePaddedOp)) {
2291 auto diag = emitSilenceableError() <<
"failed to pad op";
2292 diag.attachNote(target->
getLoc()) <<
"target op";
2297 paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2298 padOps.append(newPadOps.begin(), newPadOps.end());
2301 results.
set(cast<OpResult>(getPadded()), paddedOps);
2302 results.
set(cast<OpResult>(getPad()), padOps);
2308 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2309 if (any_of(paddingDimensions,
2310 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2311 return emitOpError() <<
"expects padding_dimensions to contain positive "
2313 << getPaddingDimensions();
2315 if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
2316 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2329 auto targetOps = state.getPayloadOps(getTarget());
2330 auto loopOps = state.getPayloadOps(getLoop());
2331 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2333 <<
"requires exactly one target and one loop handle (got "
2334 << llvm::range_size(targetOps) <<
" and "
2335 << llvm::range_size(loopOps) <<
")";
2338 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2339 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2340 if (!padOp || !loopOp)
2343 FailureOr<linalg::detail::PackingResult> result =
2349 if (result->clonedLoopIvs.empty()) {
2350 transformResults.
set(cast<OpResult>(getPackingLoop()),
2351 {result->hoistedPadOp.getOperation()});
2354 auto outerPackedLoop =
2356 transformResults.
set(cast<OpResult>(getPackingLoop()),
2357 {outerPackedLoop.getOperation()});
2363 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2364 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2366 return emitOpError() <<
"expects transpose to be a permutation, found "
2372 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2382 tensor::PadOp target,
2385 tensor::PadOp hoistedPadOp;
2387 FailureOr<Value> result =
2389 hoistedPadOp, transposeOps);
2390 if (succeeded(result)) {
2400 return emitDefaultSilenceableFailure(target);
2405 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2406 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2408 return emitOpError() <<
"expects transpose to be a permutation, found "
2424 if (!getOperandsToPromote().empty())
2426 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2427 if (getUseFullTilesByDefault())
2429 getUseFullTilesByDefault());
2431 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2432 if (!getUseFullTileBuffers().empty())
2434 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2435 if (getAlignment().has_value())
2436 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2437 if (getMemorySpace().has_value())
2438 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2440 if (getMapping().has_value()) {
2442 auto mapping = *getMapping();
2443 if (mapping.size() > 1)
2444 return emitDefaultDefiniteFailure(target);
2446 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2448 if (addressSpace.getAddressSpace() ==
2449 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2456 }
else if (addressSpace.getAddressSpace() ==
2457 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2465 return emitDefaultDefiniteFailure(target);
2470 return emitDefaultDefiniteFailure(target);
2473 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2475 return emitDefaultDefiniteFailure(target);
2488 auto payload = state.getPayloadOps(getTarget());
2492 if (target->getNumOperands() > 0)
2495 target->getNumRegions() > 0)
2497 <<
"expected target that is isolated from above";
2501 Operation *pattern = &getBodyRegion().front().front();
2504 if (getOperation()->isAncestor(target))
2509 replacements.push_back(replacement);
2511 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2515 void transform::ReplaceOp::getEffects(
2523 if (!getBodyRegion().hasOneBlock())
2524 return emitOpError() <<
"expected one block";
2525 if (std::distance(getBodyRegion().front().begin(),
2526 getBodyRegion().front().end()) != 1)
2527 return emitOpError() <<
"expected one operation in block";
2528 Operation *replacement = &getBodyRegion().front().front();
2531 <<
"expected replacement without operands";
2535 <<
"expect op that is isolated from above";
2553 target.createFlatListOfOperandDims(b, loc);
2554 AffineMap map = target.getShapesToLoopsMap();
2569 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2570 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2571 if (failed(maybeTilingResult))
2572 return emitDefaultDefiniteFailure(target);
2574 if (target->getNumResults())
2575 rewriter.
replaceOp(target, maybeTilingResult->replacements);
2579 results.
reserve(maybeTilingResult->tiledOps.size());
2580 for (
Operation *tiled : maybeTilingResult->tiledOps)
2594 for (
Operation *target : state.getPayloadOps(getTarget())) {
2595 auto tilingOp = dyn_cast<TilingInterface>(*target);
2598 emitSilenceableError()
2599 <<
"expected the payload to implement TilingInterface";
2600 diag.attachNote(target->getLoc()) <<
"payload op";
2604 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2606 if (failed(generatedLoops))
2607 return emitDefaultDefiniteFailure(target);
2608 for (scf::ForOp &loop : *generatedLoops) {
2609 loops.push_back(loop.getOperation());
2613 results.
set(cast<OpResult>(getResult()), loops);
2622 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2627 FailureOr<Operation *> maybeResult =
2629 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2630 [&rewriter](
auto op) {
2633 if (failed(maybeResult))
2634 return emitDefaultSilenceableFailure(target);
2648 llvm::to_vector(state.getPayloadOps(getTarget()));
2650 bool isMultiwaySplit = getMultiway();
2652 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2654 <<
"requires exactly one target when "
2655 "multiway split is enabled (got "
2656 << llvm::range_size(payload) <<
")";
2661 if (!isMultiwaySplit)
2662 chunkSizes.reserve(payload.size());
2664 if (getDynamicChunkSizes()) {
2666 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2667 chunkSizes = llvm::to_vector(llvm::map_range(
2668 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2671 diag = emitSilenceableError()
2672 <<
"expected dynamic split point handle to point to a "
2673 "single-result index-typed op";
2674 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2679 chunkSizes = llvm::to_vector(
2680 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2683 if (
diag.isSilenceableFailure())
2688 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2690 <<
"expected the dynamic split point handle to point to as "
2692 << chunkSizes.size() <<
") as the target handle ("
2693 << payload.size() <<
")";
2696 chunkSizes.resize(payload.size(),
2700 auto checkStructuredOpAndDimensions =
2703 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2704 diag.attachNote(loc) <<
"target op";
2708 if (getDimension() >= linalgOp.getNumLoops()) {
2709 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2710 <<
" does not exist in target op";
2711 diag.attachNote(loc) <<
"target op";
2717 auto checkFailureInSplitting =
2721 diag.attachNote(loc) <<
"target op";
2728 if (isMultiwaySplit) {
2731 TilingInterface head, tail;
2734 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2738 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2739 if (
diag.isSilenceableFailure())
2745 target = tail.getOperation();
2750 linalgOp = cast<LinalgOp>(target);
2755 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2756 getDimension(), chunkSize);
2760 checkFailureInSplitting(!head && !tail, loc);
2761 if (
diag.isDefiniteFailure())
2764 opList.push_back(head.getOperation());
2769 opList.push_back(tail.getOperation());
2775 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2778 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2780 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2782 if (
diag.isSilenceableFailure())
2786 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2787 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2788 getDimension(), std::get<1>(pair));
2792 checkFailureInSplitting(!first.back() && !second.back(), loc);
2797 if (!second.back()) {
2798 noSecondPart = target;
2803 if (second.size() != first.size() && !second.empty()) {
2804 auto diag = emitSilenceableError()
2805 <<
"splitting does not produce the second part for a subset "
2808 <<
"expected splitting to produce the second part of all "
2809 "or none of the targets";
2811 <<
"first target with no second part";
2815 opList.append(first);
2817 opList.append(second);
2819 results.
set(cast<OpResult>(getSplitList()), opList);
2823 void SplitOp::getEffects(
2826 if (getDynamicChunkSizes())
2834 IntegerAttr staticChunkSizes;
2840 if (!dynamicPointParseResult.
has_value()) {
2841 int64_t staticChunkSizesValue;
2842 if (failed(parser.
parseInteger(staticChunkSizesValue)))
2855 if (dynamicPointParseResult.
has_value()) {
2856 Type ChunkSizesType;
2857 if (failed(*dynamicPointParseResult) || parser.
parseComma() ||
2869 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2876 printer <<
" " << getTarget() <<
" after ";
2877 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2878 if (staticChunkSize != ShapedType::kDynamic)
2879 printer << staticChunkSize;
2881 printer << getDynamicChunkSizes();
2884 {getStaticChunkSizesAttrName()});
2885 printer <<
" : " << getTarget().getType();
2886 if (staticChunkSize == ShapedType::kDynamic)
2887 printer <<
", " << getDynamicChunkSizes().getType();
2891 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2892 (getDynamicChunkSizes() ==
nullptr)) {
2893 return emitOpError() <<
"expects either a dynamic or a static split "
2894 "point to be provided";
2903 void transform::SplitReductionOp::build(
2905 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2906 bool useScalingAlgorithm,
bool useAlloc) {
2912 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2914 if (innerParallel) {
2915 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2918 if (useScalingAlgorithm) {
2920 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2928 result.
addTypes({resultType, resultType, resultType, resultType});
2937 unsigned(getInsertSplitDimension()),
2938 bool(getInnerParallel())};
2941 FailureOr<SplitReductionResult> splitResult =
2942 (getUseScalingAlgorithm())
2945 if (failed(splitResult))
2946 return emitDefaultDefiniteFailure(target);
2948 results.
push_back(splitResult->initOrAlloc);
2950 results.
push_back(splitResult->splitLinalgOp);
2951 results.
push_back(splitResult->resultCombiningLinalgOp);
2959 void transform::TileReductionUsingForOp::build(
2970 build(builder, result,
2973 staticTileSizesAttr);
2982 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2983 if (!partialReductionOp) {
2986 "Operation should implement PartialReductionOpInterface");
2989 rewriter, partialReductionOp,
2993 return emitDefaultSilenceableFailure(target);
2994 rewriter.
replaceOp(target, result->replacements);
2995 for (
Value initValue : result->initialValues)
2997 for (
auto parallelTiledOp : result->tiledOps)
2999 for (
auto mergeOp : result->mergeOps)
3001 results.
push_back(result->loops.front());
3009 void transform::TileReductionUsingForallOp::build(
3012 ArrayAttr mapping) {
3022 build(builder, result,
3025 staticNumThreadsAttr,
3026 staticTileSizesAttr,
3039 FailureOr<linalg::ForallReductionTilingResult> result =
3041 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
3042 numThreads, tileSizes, getMapping());
3044 if (failed(result)) {
3045 auto diag = emitSilenceableError() <<
"could not tile reduction";
3046 diag.attachNote(target.getLoc()) <<
"target operation";
3049 for (
Value initValue : result->initialValues)
3051 for (
auto parallelTiledOp : result->parallelTiledOps)
3053 for (
auto mergeOp : result->mergeOps)
3069 llvm::to_vector(state.getPayloadOps(getTarget()));
3071 if (!llvm::hasSingleElement(targetOps)) {
3073 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3078 auto linalgOp = dyn_cast<LinalgOp>(target);
3079 auto tileableOp = dyn_cast<TilingInterface>(target);
3084 OpBuilder builder(linalgOp.getContext());
3086 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3087 if (linalgOp.hasDynamicShape()) {
3088 auto diag = emitSilenceableError()
3089 <<
"cannot compute parametric tile sizes for dynamically "
3090 "shaped payload op";
3091 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3095 FailureOr<StaticContinuousTileSizeSpecification> spec =
3099 return emitSilenceableError()
3100 <<
"failed to compute multi-size tiling sizes";
3105 for (
auto &&[tileSize, tripCount] :
3106 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3107 chunkSizes.push_back(tileSize * tripCount);
3110 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
3115 getI64AttrsFromI64(spec->tileSizes));
3116 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3117 getI64AttrsFromI64(chunkSizes));
3125 unsigned dimension = getDimension();
3128 builder, tileableOp, dimension, targetSize,
true);
3130 return emitSilenceableError() <<
"could not generate tile size computation";
3142 for (
auto &&[tileSize, tripCount] :
3143 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3144 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3145 chunkSizes.push_back(splitPoint);
3149 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
3155 getDefiningOps(spec->tileSizes));
3156 transformResults.
set(cast<OpResult>(getChunkSizes()),
3157 getDefiningOps(chunkSizes));
3165 return emitOpError() <<
"expects all results type to be the same";
3171 void transform::ContinuousTileSizesOp::getEffects(
3189 Type &tileSizesType,
3190 Type &chunkSizesType) {
3191 FunctionType funcType;
3193 if (failed(parser.
parseType<FunctionType>(funcType)))
3196 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3197 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3198 "argument and one result";
3200 targetType = funcType.getInput(0);
3201 tileSizesType = chunkSizesType = funcType.getResult(0);
3210 void transform::TileUsingForOp::build(
3215 return build(builder, result, loopTypes,
3219 interchange, scalableSizes);
3222 void transform::TileUsingForOp::build(
3226 build(builder, result, target,
3228 interchange, scalableSizes);
3231 void transform::TileUsingForOp::build(
3238 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3242 void transform::TileUsingForOp::build(
3254 unsigned numExpectedLoops =
3255 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3257 resultTypes.reserve(numExpectedLoops);
3258 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3259 "expected one loop type or as many as loops");
3260 if (loopTypes.size() == 1)
3261 resultTypes.append(numExpectedLoops, loopTypes[0]);
3263 llvm::append_range(resultTypes, loopTypes);
3265 if (scalableSizes.has_value())
3266 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3267 build(builder, result, target.
getType(),
3271 staticTileSizesAttr,
3273 expandedScalableSizes);
3278 return emitOpError(
"expected same number of sizes (")
3280 << getScalableSizes().size() <<
")";
3282 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3283 if (getLoops().size() != numExpectedLoops)
3284 return emitOpError(
"expected number of loops to tile (")
3285 << numExpectedLoops <<
") to match number of `loops` results ("
3286 << getLoops().size() <<
")";
3297 llvm::to_vector(state.getPayloadOps(getTarget()));
3303 if (isa<ParamType>(transformValue.getType())) {
3304 dynamicSizeProducers.push_back({});
3306 paramSizes.push_back(
3307 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
3308 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3311 if (paramSizes.back().size() != targets.size()) {
3313 emitSilenceableError()
3314 <<
"expected as many parameter values ("
3315 << dynamicSizeProducers.back().size() <<
") as target ops ("
3316 << targets.size() <<
")";
3317 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3323 paramSizes.push_back({});
3324 dynamicSizeProducers.push_back(
3325 llvm::to_vector(state.getPayloadOps(transformValue)));
3327 if (dynamicSizeProducers.back().size() != targets.size()) {
3329 emitSilenceableError()
3330 <<
"expected as many dynamic size-producing operations ("
3331 << dynamicSizeProducers.back().size() <<
") as target ops ("
3332 << targets.size() <<
")";
3333 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3337 for (
Operation *op : dynamicSizeProducers.back()) {
3344 emitSilenceableError() <<
"expected sizes to be produced by ops "
3345 "with a single index-type result";
3346 diag.attachNote(op->
getLoc()) <<
"size producer op";
3347 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3354 loops.resize(getLoops().size());
3355 auto scalableSizes = getScalableSizes();
3357 auto tilingInterface = dyn_cast<TilingInterface>(op);
3358 if (!tilingInterface) {
3360 emitSilenceableError()
3361 <<
"only ops implementing TilingInterface are supported";
3362 diag.attachNote(op->
getLoc()) <<
"target op";
3365 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3367 emitSilenceableError()
3368 <<
"too many tiles provided, expected at most "
3369 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3370 << tileSizes.size();
3371 diag.attachNote(op->
getLoc()) <<
"target op";
3376 if (tileSizes.empty()) {
3385 sizes.reserve(tileSizes.size());
3386 unsigned dynamicIdx = 0;
3389 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3390 if (scalableSizes[ofrIdx]) {
3391 auto val = b.
create<arith::ConstantIndexOp>(
3392 getLoc(), cast<IntegerAttr>(attr).getInt());
3398 sizes.push_back(attr);
3405 assert((dynamicSizes.empty() ^ params.empty()) &&
3406 "expected either dynamic sizes or parameters");
3407 if (!params.empty()) {
3410 sizes.push_back(dynamicSizes[index]->getResult(0));
3418 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3419 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3420 if (failed(maybeTilingResult))
3423 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3425 tiled.append(maybeTilingResult->tiledOps);
3427 loops[en2.index()].push_back(en2.value());
3430 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3432 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3441 results.reserve(tileSizes.size());
3442 unsigned dynamicPos = 0;
3444 for (int64_t size : tileSizes) {
3445 if (size == ShapedType::kDynamic) {
3446 results.push_back(dynamic[dynamicPos++]);
3454 void transform::TileUsingForOp::getEffects(
3466 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3470 ArrayAttr mapping) {
3471 return build(builder, result,
3479 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3483 ArrayAttr mapping) {
3493 build(builder, result,
3494 TypeRange{operationType, operationType},
3501 staticTileSizesAttr,
3505 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3509 ArrayAttr mapping) {
3510 return build(builder, result, target,
3515 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3519 ArrayAttr mapping) {
3530 build(builder, result,
3531 TypeRange{operationType, operationType},
3537 staticNumThreadsAttr,
3550 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3552 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3554 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3555 normalizedUbs.push_back(normalizedUb);
3557 return normalizedUbs;
3573 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3576 denormalizedIvs.push_back(
3579 return denormalizedIvs;
3590 scf::ForallOp loop) {
3607 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
3608 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3611 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3613 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3618 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3619 normalizedForallOp.getRegionIterArgs().end());
3620 Block *origLoopBlock = loop.getBody();
3621 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3623 rewriter.
replaceOp(loop, normalizedForallOp);
3624 return normalizedForallOp;
3629 TransformOpInterface transformOp,
Operation *target,
3634 auto tileableOp = dyn_cast<TilingInterface>(target);
3637 transformOp.emitSilenceableError()
3638 <<
"only TilingInterface ops are supported";
3639 diag.attachNote(target->
getLoc()) <<
"target op";
3645 if (!mixedNumThreads.empty()) {
3646 options.setNumThreads(mixedNumThreads);
3648 options.setTileSizes(mixedTileSizes);
3651 options.setMapping(mapping.value().getValue());
3653 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3656 if (failed(maybeTilingResult))
3657 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3659 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3661 tilingResult = *maybeTilingResult;
3663 if (mixedNumThreads.empty()) {
3664 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3667 scf::ForallOp normalizedForallOp =
3669 tilingResult.
loops.front() = normalizedForallOp;
3679 auto transformOp = cast<TransformOpInterface>(getOperation());
3688 getPackedNumThreads()
3690 state, transformOp, mixedNumThreads, getPackedNumThreads())
3692 state, transformOp, mixedNumThreads, getMixedNumThreads());
3696 status = getPackedTileSizes()
3698 state, transformOp, mixedTileSizes, getPackedTileSizes())
3700 state, transformOp, mixedTileSizes, getMixedTileSizes());
3704 for (
Operation *target : state.getPayloadOps(getTarget())) {
3707 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3708 getMapping(), tilingResult);
3709 if (!
diag.succeeded())
3711 tileOps.push_back(tilingResult.
loops.front());
3712 tiledOps.append(tilingResult.
tiledOps);
3715 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3716 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3721 void transform::TileUsingForallOp::getEffects(
3734 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3743 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3744 static_cast<int>(getPackedNumThreads() !=
Value());
3745 if (numThreadsSpec > 1)
3747 "num_threads and packed_num_threads are mutually exclusive");
3748 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3749 static_cast<int>(getPackedTileSizes() !=
Value());
3750 if (tileSizesSpec > 1)
3752 "tile_sizes and packed_tile_sizes are mutually exclusive");
3753 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3754 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3755 "must be specified");
3763 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3765 bool vectorizePadding,
bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3767 if (vectorizePadding) {
3769 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3773 if (vectorizeExtract) {
3775 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3779 if (flatten1DDepthwiseConv) {
3781 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3792 explicit VectorizationPattern(
MLIRContext *context,
3793 bool vectorizeExtract =
false,
3794 bool flattenConv =
false)
3796 vectorizeNDExtract(vectorizeExtract),
3797 flatten1DDepthwiseConv(flattenConv) {}
3798 LogicalResult matchAndRewrite(
Operation *op,
3802 "Unsupported Op, cannot vectorize");
3804 {}, vectorizeNDExtract,
3805 flatten1DDepthwiseConv);
3811 bool vectorizeNDExtract =
false;
3815 bool flatten1DDepthwiseConv =
false;
3820 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3825 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3826 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3832 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3833 getFlatten_1dDepthwiseConv());
3835 if (!getDisableTransferPermutationMapLoweringPatterns())
3838 if (!getDisableMultiReductionToContractPatterns())
3846 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
3847 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
3852 if (getVectorizePadding()) {
3864 return emitDefaultDefiniteFailure(target);
3878 auto targets = state.getPayloadOps(getTarget());
3879 if (std::empty(targets))
3881 auto transformOp = cast<TransformOpInterface>(getOperation());
3884 state, transformOp, getMixedVectorSizes(), vectorSizes);
3892 <<
"Unsupported Op, cannot vectorize";
3897 getVectorizeNdExtract().value_or(
false)))) {
3899 <<
"Attempted to vectorize, but failed";
3906 void transform::VectorizeOp::getEffects(
3915 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3919 if (getStaticVectorSizes().size() != getScalableSizes().size())
3920 return emitOpError(
"expected same number of vector sizes (")
3921 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3922 << getScalableSizes().size() <<
")";
3931 transform::HoistRedundantVectorTransfersOp::applyToOne(
3948 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3967 auto maybeTransformed =
3970 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3973 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3976 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3979 .Case([&](linalg::Conv2DNchwFchwOp op) {
3985 if (failed(maybeTransformed))
3986 return emitDefaultSilenceableFailure(target);
3988 results.
push_back(maybeTransformed->first);
3990 results.
push_back(maybeTransformed->second);
4005 <<
"only elementwise flattening is supported";
4008 if (target.getNumLoops() <= 1) {
4015 std::iota(reassociation.begin(), reassociation.end(), 0);
4016 auto maybeFlattened =
4018 if (failed(maybeFlattened))
4020 <<
"attempted to flatten, but failed";
4021 results.
push_back(maybeFlattened->collapsedOp);
4022 rewriter.
replaceOp(target, maybeFlattened->results);
4035 auto maybeTransformed =
4037 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4040 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4046 if (failed(maybeTransformed))
4047 return emitDefaultSilenceableFailure(target);
4062 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4063 auto maybeTransformed =
4065 .Case([&](linalg::MatmulOp op) {
4068 .Case([&](linalg::BatchMatmulOp op) {
4071 .Default([&](
Operation *op) {
return failure(); });
4072 if (failed(maybeTransformed))
4082 template <
typename OpTy>
4086 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4087 tensor::ParallelInsertSliceOp>() &&
4090 if (
auto copySource =
4091 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4098 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
4100 target->template getParentOfType<scf::InParallelOp>());
4103 Value extracted = rewriter.
create<tensor::ExtractSliceOp>(
4104 target.getLoc(), target.getDest(), target.getMixedOffsets(),
4105 target.getMixedSizes(), target.getMixedStrides());
4106 Value copied = rewriter
4107 .
create<linalg::CopyOp>(target.getLoc(),
4108 target.getSource(), extracted)
4113 target, copied, target.getDest(), target.getMixedOffsets(),
4114 target.getMixedSizes(), target.getMixedStrides());
4116 results.
push_back(copied.getDefiningOp());
4126 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4127 return doit(rewriter, target, results, state);
4128 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4129 return doit(rewriter, target, results, state);
4132 emitSilenceableError()
4133 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4134 diag.attachNote(targetOp->
getLoc()) <<
"target op";
4147 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4149 emitSilenceableError()
4150 <<
"only linalg.copy and tensor.pad target ops are supported";
4151 diag.attachNote(target->
getLoc()) <<
"target op";
4154 assert(target->
getNumResults() == 1 &&
"expected single result");
4156 if (!resultShapedType.hasStaticShape()) {
4158 emitSilenceableError()
4159 <<
"only statically sized ops of rank <= 3 are supported";
4160 diag.attachNote(target->
getLoc()) <<
"target op";
4165 int64_t desiredBitAlignment = getDesiredBitAlignment();
4166 int64_t eltBitwidth =
4167 resultShapedType.getElementType().getIntOrFloatBitWidth();
4168 if (desiredBitAlignment % eltBitwidth != 0) {
4169 desiredBitAlignment = eltBitwidth;
4174 getTotalNumThreads(),
4175 desiredBitAlignment,
4176 resultShapedType.getShape(),
4179 resultShapedType.getElementType().getIntOrFloatBitWidth());
4180 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4182 emitSilenceableError()
4183 <<
"too few threads to map copy op to threads on the most minor "
4184 "dimension, given alignment and vector size constraints, try "
4185 "smaller tile size of mapping to more threads";
4186 diag.attachNote(target->
getLoc()) <<
"target op";
4202 if (!
diag.succeeded())
4206 for (
auto op : tilingResult.
tiledOps)
4220 FailureOr<Operation *> maybeTransformed = failure();
4222 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4227 .Default([&](
Operation *op) {
return false; });
4230 return emitSilenceableError()
4231 <<
"this operation is not supported to convert to Winograd Conv2D";
4234 if (failed(maybeTransformed)) {
4235 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4247 FailureOr<Operation *> maybeTransformed = failure();
4250 .Case([&](linalg::WinogradFilterTransformOp op) {
4254 .Case([&](linalg::WinogradInputTransformOp op) {
4258 .Case([&](linalg::WinogradOutputTransformOp op) {
4262 .Default([&](
Operation *op) {
return false; });
4266 emitSilenceableError()
4267 <<
"this operation is not supported to decompose into other operations";
4268 diag.attachNote(target->
getLoc()) <<
"target op";
4272 if (failed(maybeTransformed)) {
4274 emitSilenceableError() <<
"decompose Winograd operations failed";
4275 diag.attachNote(target->
getLoc()) <<
"target op";
4283 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4285 #define GET_OP_CLASSES
4286 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
uint64_t getM(LevelType lt)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs))
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs))
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.