#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
template <typename PatternTy, typename... Args>
  using OpTy = typename llvm::function_traits<
      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
  auto op = dyn_cast<OpTy>(operation);

  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
    explicit TrivialPatternRewriter(MLIRContext *context)
  TrivialPatternRewriter rewriter(operation->getContext());
  rewriter.setInsertionPoint(operation);
  auto result = pattern.returningMatchAndRewrite(op, rewriter);
  return cast<LinalgOp>(result->getOperation());
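// Note: tryApply<PatternTy>(op, args...) above drives a single rewrite pattern
// outside of a pass and returns the rewritten LinalgOp on success. A hedged
// usage sketch (illustrative only; the pattern is one of those used by the
// DOWNSCALE macros further down):
//
//   FailureOr<LinalgOp> res = tryApply<DownscaleSizeOneWindowed2DConvolution<
//       linalg::Conv2DNhwcHwcfOp, linalg::Conv1DNwcWcfOp>>(target);
//   if (succeeded(res))
//     results.push_back(*res);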
    if (!isa<IntegerAttr>(ofr.get<Attribute>()))
      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
    result.push_back(ofr);

    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
    if (!llvm::hasSingleElement(payloadOps)) {
          transformOp.emitSilenceableError()
          << "handle must be mapped to exactly one payload op";
          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";

          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";

  for (Operation *op : state.getPayloadOps(packedHandle)) {
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(

void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(

void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
  options.rankReductionStrategy =

void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               int64_t memorySpace) {
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
  void notifyOperationInserted(Operation *op) override {
    ForwardingListener::notifyOperationInserted(op);
    auto inserted = newOps.insert(op);
    assert(inserted.second && "expected newly created op");

  void notifyOperationRemoved(Operation *op) override {
    ForwardingListener::notifyOperationRemoved(op);

      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
  NewOpsListener newOpsListener(previousListener);

  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
  } else if (getMemcpyOp() == "memref.copy") {
  } else if (getMemcpyOp() == "linalg.copy") {
    llvm_unreachable("invalid memcpy op");
  if (getAllocOp() == "memref.alloc") {
  } else if (getAllocOp() == "memref.alloca") {
    llvm_unreachable("invalid alloc op");
  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
  options.emitDealloc = getEmitDealloc();

      getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
  for (Operation *op : state.getPayloadOps(getTarget())) {
          << "failed to bufferize operation";
      diag.attachNote(op->getLoc()) << "target payload op";
    allocatedBuffers.push_back(buffer);

  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
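// Illustrative (hedged) transform IR driving the op above; the attribute
// spelling and result types are assumptions and may differ between versions:
//
//   %buffer, %new_ops = transform.structured.bufferize_to_allocation %op
//       {memory_space = 3, memcpy_op = "memref.copy"} : !transform.any_op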
void transform::BufferizeToAllocationOp::getEffects(
  if (getBufferizeDestinationOnly()) {

  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
      getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
    return emitOpError() << "unsupported memcpy op";
  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
    return emitOpError() << "unsupported alloc op";
#define DOWNSCALE(trans)                                                       \
    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
    if (succeeded(res)) {                                                      \
      results.push_back(*res);                                                 \
      return DiagnosedSilenceableFailure::success();                           \

#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))

#undef DOWNSCALE_NORMAL
#undef DOWNSCALE_CALL
  return emitDefaultSilenceableFailure(target);
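// The DOWNSCALE machinery above expands, for each supported size-one-window
// 2-D convolution, into a tryApply of DownscaleSizeOneWindowed2DConvolution;
// the first pattern that succeeds yields the rewritten 1-D op and the
// transform returns success, otherwise it falls through to the default
// silenceable failure.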
  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
  if (!decomposableOp) {
                               "payload is not a decomposable op"));
    return emitDefaultSilenceableFailure(target);

      decomposableOp.decomposeOperation(rewriter);
  if (failed(maybeNewResults))
    return emitDefaultSilenceableFailure(target);

  rewriter.replaceOp(decomposableOp, *maybeNewResults);
  for (Value val : *maybeNewResults) {
    Operation *definition = val.getDefiningOp();
void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(

transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
  options.allowReturnAllocsFromLoops = true;

  for (Operation *target : state.getPayloadOps(getTarget())) {
             << "failed to analyze op";
            rewriter, target, state)))
             << "failed to eliminate LinalgOp anchored tensor.empty ops";
template <typename Range>
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");

        applyFn(tilingInterfaceOp);

    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res))
      if (toReplace->use_empty()) {

    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);

  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
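// applyTilingToAll (sketched above) runs the tiling callback on each payload
// op, replaces uses of the fused producers with the tiled replacements, and
// maps result #0 of the transform op to the tiled linalg ops and results
// #1..numLoops to the generated loops.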
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
      rewriter, getOperation(), state.getPayloadOps(getTarget()),
      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
      [&](TilingInterface tilingInterfaceOp)
        return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
            rewriter, tilingInterfaceOp, tileAndFuseOptions);

  FunctionType trailingType;
  if (trailingType.getNumInputs() != 1)
    return parser.emitError(typeLoc) << "expected one input type";
  result.addTypes(trailingType.getResults());
  if (parser.resolveOperand(targetOperand, trailingType.getInput(0),
                        getResults().getTypes());

      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError() << "expects interchange to be a permutation, found "
                         << getTileInterchange();

      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
  if (numExpectedLoops != getNumResults() - 1)
    return emitOpError() << "expects " << numExpectedLoops << " loop results";
void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                            Value containingOp) {
  result.addTypes({resultType, resultType});

        (domInfo.dominates(containingOp, user))) {
      dominatedUsers.insert(user);
  if (dominatedUsers.empty())

  auto forallOp = cast<scf::ForallOp>(containingOp);
  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
  newOuts.push_back(outputs[resultNumber]);

  auto newforallOp = rewriter.create<scf::ForallOp>(
      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
  newforallOp.getRegion().takeBody(forallOp.getRegion());
  newforallOp.getBody()->addArgument(newOuts.back().getType(),
                                     newOuts.back().getLoc());
  auto bbArgs = newforallOp.getBody()->getArguments();

  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
      terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
  Operation *firstYieldOp = yieldingOps.front();
  Value dst = newforallOp.getOutputBlockArguments().back();
  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
                                                 dst, offsets, sizes, strides);
                          newforallOp->getResult(result.index()));
      newforallOp->getResults().back(),
        return dominatedUsers.contains(user);
static std::tuple<SmallVector<Operation *>, Operation *>
  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
         << "producer is not a TileableInterface: " << *producerOp;

  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);

  if (it == tileableProducer->getUsers().end()) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find fusion opportunity for: " << *tileableProducer;
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);

  int64_t resultNumber =
      cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

      tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;

  for (auto tiledOp : tileAndFuseResult->tiledOps) {
    LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");

  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  if (failed(maybeRankReduced)) {
        << "shape types don't match (missing canonicalization?):\nTiledOp: "
        << tileAndFuseResult->tiledValues[0]
        << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

      rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
      resultNumber, offsets, sizes);

  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
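// Summary of the helper above: it finds the first tensor.extract_slice of the
// producer's result inside the containing op, tiles the producer to that slice
// via generateResultTileValue, rank-reduces the tiled value if needed, and
// replaces the slice with it; the tiled ops and the (possibly rebuilt)
// containing op are returned as a tuple.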
  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");

  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
         << "producer is not a TileableInterface: " << *producerOp;

  scf::ForallOp forallOp;
  auto itProducerUses =
      llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
        forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
  if (!forallOp || forallOp != containingOp) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find a use by the containing op: " << *tileableProducer;

    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);

  if (itBBArgUsers == bbArg.getUsers().end()) {
        << "could not find fusion opportunity for bbArg: " << bbArg;
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

          rewriter, tileableProducer->getLoc(), tileableProducer,
          destinationTensors))) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to get destination tensors for: " << *tileableProducer;

  bvm.map(destinationTensors[resultNumber], bbArg);
  auto tileableProducerClone =
      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

      tileableProducerClone.generateResultTileValue(
          rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
          sliceOpToTile.getMixedSizes());
  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;

  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  assert(succeeded(maybeRankReduced) && "unexpected shape");
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

                         destinationTensors.front());

  return tileAndFuseResult->tiledOps;
  LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");

  for (OpOperand &use : result.getUses()) {
      uses.push_back(&use);
    if (containingOp == use.getOwner()) {
          << "producer op use by containing op cannot be fused by cloning";

    diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";

  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
         "Parallel insert slice is not a valid clone destination");
  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

    fusedOp = rewriter.clone(*producerOp);
        use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {

  auto producerOps = state.getPayloadOps(getProducerOp());
  auto containingOps = state.getPayloadOps(getContainingOp());
  if (!llvm::hasSingleElement(containingOps)) {
           << "requires exactly one containing_op handle (got "
           << llvm::range_size(containingOps) << ")";
  Operation *containingOp = *containingOps.begin();

  if (std::empty(producerOps)) {
    results.set(cast<OpResult>(getNewContainingOp()), {containingOp});

    for (const auto &it : enumerate(remainingProducers)) {
      int64_t numUsesInContainingOp =
            return containingOp->isAncestor(op);
      if (numUsesInContainingOp > 0) {
        if (numUsesInContainingOp == 1)
          remainingProducers.erase(remainingProducers.begin() + it.index());

  while (!remainingProducers.empty()) {
    auto nextProducer = getNextProducer();
    if (failed(nextProducer)) {
           << "could not find next producer to fuse into container";
      diag.attachNote(containingOp->getLoc()) << "containing op";

    diag << "could not fuse " << *producerOp << " into " << *containingOp;

    auto [tiledOps, newContainingOp] =
    if (!tiledOps.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
      fusedOps.append(tiledOps);
      if (newContainingOp) {
        (void)replacementStatus;
               "unable to update transform state mapping");
        rewriter.eraseOp(containingOp);
        containingOp = newContainingOp;

        rewriter, diag, producerOp, containingOp);
    if (!tiledContainingOpOperand.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
      fusedOps.append(tiledContainingOpOperand);

      LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
      fusedOps.push_back(cloned);

  results.set(cast<OpResult>(getFusedOp()), fusedOps);
  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
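// Illustrative (hedged) transform IR driving the op above; result types are
// assumptions and may differ by version:
//
//   %fused, %new_containing =
//       transform.structured.fuse_into_containing_op %producer into %forall
//       : (!transform.any_op, !transform.any_op)
//         -> (!transform.any_op, !transform.any_op)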
void transform::FuseIntoContainingOp::getEffects(

  if (isa<GenericOp>(target)) {
  results.push_back(generic->getOperation());
  return emitDefaultSilenceableFailure(target);

  if (!isa<GenericOp>(target)) {
  results.push_back(named->getOperation());
  return emitDefaultSilenceableFailure(target);

  if (interchangeVector.empty()) {
  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
  if (interchangeVector.size() != numLoops) {
    return emitSilenceableError()
           << getIteratorInterchangeAttrName() << " has length ("
           << interchangeVector.size()
           << ") different from the number of loops in the target operation ("
                                                  interchangeVector.end()));

  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError()
           << "expects iterator_interchange to be a permutation, found "
           << getIteratorInterchange();
           << "cannot lower to pad + expand + transpose";
  transformResults.push_back(res->expandShapeOp);
  transformResults.push_back(res->transposeOp);

           << "cannot rewrite to pad + expand + transpose";
  transformResults.push_back(res->emptyOp);
  transformResults.push_back(res->transposeOp);
  transformResults.push_back(res->collapseShapeOp);
  transformResults.push_back(res->extractSliceOp);
  if (getOps().has_value())
    strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
                getOps()->getAsValueRange<StringAttr>().end());

  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps)) {

    if (getInterface().has_value()) {
      auto iface = getInterface().value();
      if (iface == transform::MatchInterfaceEnum::LinalgOp &&
      if (iface == transform::MatchInterfaceEnum::TilingInterface &&
          !isa<TilingInterface>(op))
      if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
          !isa<LoopLikeOpInterface>(op))

    if (getOpAttrs().has_value()) {
      DictionaryAttr opAttrs = getOpAttrs().value();
        if (attr.getName() == getInterfaceAttrName() ||
            attr.getName() == getOpsAttrName())
        if (!op->hasAttr(attr.getName()))
        if (op->getAttr(attr.getName()) != attr.getValue())

    if (getFilterResultType().has_value()) {
      Type t = getFilterResultType().value();

  (*payloadOps.begin())->walk(matchFun);
  results.set(cast<OpResult>(getResult()), res);
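// The match logic above walks the single payload op mapped to the target
// handle and filters candidates by op name, implemented interface (LinalgOp,
// TilingInterface, LoopLikeInterface), attribute dictionary, and result type
// before collecting the survivors into the result handle.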
                                  Type &targetType, Type &lowSizeType,
                                  Type &splitPointType) {
  FunctionType funcType;
  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
  targetType = funcType.getInput(0);
  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);

  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
    if (target.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(target->getLoc()) << "payload op";

        target, getDimension(), getTargetSize(), getDivisor());
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";

    Builder builder(target.getContext());
    results.assign(llvm::map_range(
                        spec->lowTileSize * spec->lowTripCount}),
        [&builder, this](int64_t value) {
              cast<ParamType>(getLowSize().getType()).getType(), value);

      builder, target, getDimension(), targetSize, divisor);
    return emitSilenceableError() << "could not generate tile size computation";

                             {spec->lowTileSize, spec->lowTripCount});
  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
  Operation *highTileSize = spec->highTileSize.getDefiningOp();
  assert(lowTileSize && highTileSize && splitPoint &&
         "tile sizes are not produced by operations");
void transform::MultiTileSizesOp::getEffects(
  if (isa<TransformParamTypeInterface>(getLowSize().getType()))

  if (getLowSize().getType() != getHighSize().getType() ||
      getLowSize().getType() != getSplitPoint().getType()) {
    return emitOpError() << "expects all results type to be the same";

      builder.getContext(), GenericOp::getOperationName());
  build(builder, result,

  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);

  auto targetOps = state.getPayloadOps(getTarget());
  if (std::empty(targetOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()),

  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 LinalgOp (got "
           << llvm::range_size(targetOps) << ")";

  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
    return emitSilenceableError()
           << "requires number of packed sizes match the number of loops ("
           << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()

      state, *this, packedSizes, getMixedPackedSizes());

  transformResults.set(cast<OpResult>(getPackedOp()),
                       {maybeResult->packedLinalgOp.getOperation()});
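// Illustrative (hedged) use of the pack transform; operand syntax is an
// assumption based on upstream tests and may differ by version:
//
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//       : (!transform.any_op) -> (!transform.any_op)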
void transform::PackOp::getEffects(

    return emitOpError() << getMatmulInnerDimsOrderAttrName()
                         << " is not a valid permutation";

  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
    for (auto [s, nmo] :
         llvm::zip_equal(getMixedMatmulPackedSizes(),
                         getMatmulPaddedSizesNextMultipleOf())) {
          (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
        return emitOpError() << "at most one of the packed_size and the "
                                "padded_sizes_next_multiple_of can be nonzero "
                                "for the matmul strategy";

  for (Operation *op : state.getPayloadOps(getTarget())) {
    auto linalgOp = dyn_cast<LinalgOp>(op);
            getMixedMatmulPackedSizes(),
            getMatmulPaddedSizesNextMultipleOf(),
            getMatmulInnerDimsOrder());
      results.push_back(packResult->packedLinalgOp);
    results.push_back(linalgOp);

  transformResults.set(cast<OpResult>(getPackedOp()), results);

  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),

void transform::PackGreedilyOp::getEffects(
    return emitOpError() << getInnerPermAttrName()
                         << " is not a valid permutation";
    return emitOpError() << getOuterPermAttrName()
                         << " is not a valid permutation";
  if (getInnerPerm().empty() && getOuterPerm().empty()) {
    return emitOpError() << " at least one of " << getInnerPermAttrName()
                         << " or " << getOuterPermAttrName()
                         << " must be specified";

enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };

template <typename RelayoutOpTy>
                 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
      llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
      "applies to only pack or unpack operations");
  if (!op || permutation.empty())
  size_t innerRank = op.getInnerDimsPos().size();
  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
    return permutation.size() == op.getSourceRank() &&
  return permutation.size() == op.getDestRank() &&

  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
  if (std::empty(packOrUnpackOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()), {});
    transformResults.set(cast<OpResult>(getPackOp()), {});
    transformResults.set(cast<OpResult>(getUnPackOp()), {});

  if (!llvm::hasSingleElement(packOrUnpackOps) ||
      !llvm::hasSingleElement(linalgOps)) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 "
              "packing op and 1 packed op ("
           << "got " << llvm::range_size(packOrUnpackOps) << " and "
           << llvm::range_size(linalgOps) << ")";

  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
  if ((!packOp && !unPackOp)) {
    return emitSilenceableError() << "requires target to map to a "
                                     "tensor.pack or tensor.unpack";
  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
  if (!linalgOpTarget)
    return emitSilenceableError() << "requires a LinalgOp target";

  if (packOp && packOp.getResult().hasOneUse())
    linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
    linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
  if (linalgOp != linalgOpTarget) {
        packOp ? StringLiteral{"not a single use by the LinalgOp target"}
               : StringLiteral{"not produced by the LinalgOp target"};
    return emitSilenceableError() << errorMsg;

    assert(!packOp && "packOp must be null on entry when unPackOp is not null");
    OpOperand *packUse = linalgOp.getDpsInitOperand(
        cast<OpResult>(unPackOp.getSource()).getResultNumber());
    if (!packOp || !packOp.getResult().hasOneUse())
      return emitSilenceableError() << "could not find matching pack op";

  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
        (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
    auto errorMsg = (permType == OuterOrInnerPerm::Outer)
                        ? StringLiteral{"invalid outer_perm"}
                        : StringLiteral{"invalid inner_perm"};
          unPackOp ? unPackOp.getOperation() : packOp.getOperation();
      return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;

  assert(packOp && linalgOp && "unexpected null op");
      rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
  assert(succeeded(res) && "unexpected packTranspose failure");

  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
  transformResults.set(cast<OpResult>(getPackedOp()),
                       {res->transposedLinalgOp});
    transformResults.set(cast<OpResult>(getUnPackOp()),
                         {res->transposedUnPackOp});
    transformResults.set(cast<OpResult>(getUnPackOp()), {});
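// The pack_transpose handling above requires exactly one tensor.pack or
// tensor.unpack plus one LinalgOp, checks that outer_perm / inner_perm are
// valid permutations of the corresponding ranks, and then delegates to
// linalg::packTranspose, returning the transposed pack/unpack and linalg ops.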
                             StringRef copyBackOp) {
            (padToMultipleOf.empty() ? ArrayAttr()
                                     : b.getI64ArrayAttr(padToMultipleOf)),
            b.getI64ArrayAttr(packPaddings),
            b.getArrayAttr(transposePaddings),
            b.getStringAttr(copyBackOp));

  for (Operation *target : state.getPayloadOps(getTarget())) {
    auto linalgTarget = dyn_cast<LinalgOp>(target);
    if (!linalgTarget) {
      auto diag = emitSilenceableError() << "expected LinalgOp target";
      diag.attachNote(target->getLoc()) << "target op";

    for (int64_t packPadding :
         extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
      packPaddings.push_back(static_cast<bool>(packPadding));

    for (auto const &it :
         llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
      auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
        emitOpError("expects padding values to be typed attributes");
      if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
        if (!parsedAttr || parsedAttr.getType() != elementType) {
          auto diag = this->emitOpError("expects a padding that parses to ")
                      << elementType << ", got " << std::get<0>(it);
          diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
        paddingValues.push_back(parsedAttr);
      if (attr.getType() != elementType) {
        auto diag = this->emitOpError("expects a padding value of type ")
                    << elementType << ", got " << attr;
        diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
      paddingValues.push_back(attr);

    for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
      transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
          cast<ArrayAttr>(transposeVector)));

        extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
    if (getPadToMultipleOf().has_value())
          extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
    options.padToMultipleOf = padToMultipleOf;
    options.paddingValues = paddingValues;
    options.packPaddings = packPaddings;
    if (getCopyBackOp() ==
        bufferization::MaterializeInDestinationOp::getOperationName()) {
    } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
    } else if (getCopyBackOp() == kCopyOpNone) {
      llvm_unreachable("unsupported copy_back op");

                            replacements, newPadOps))) {
      auto diag = emitSilenceableError() << "failed to pad op";
      diag.attachNote(target->getLoc()) << "target op";

    rewriter.replaceOp(linalgTarget, replacements);
    paddedOps.push_back(paddedOp);
    padOps.append(newPadOps.begin(), newPadOps.end());

    for (Value v : replacements) {
      Operation *copyBackOp = v.getDefiningOp();
      if (llvm::find(copyBackOps, copyBackOp) == copyBackOps.end())
        copyBackOps.push_back(copyBackOp);

  results.set(cast<OpResult>(getPadded()), paddedOps);
  results.set(cast<OpResult>(getPad()), padOps);
  results.set(cast<OpResult>(getCopy()), copyBackOps);
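// Illustrative (hedged) transform IR for padding; the attribute names follow
// the accessors above but the exact list may differ by version:
//
//   %padded, %pad, %copy = transform.structured.pad %matmul
//       {padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//        padding_dimensions = [0, 1, 2], pack_paddings = [1, 1, 0]}
//       : (!transform.any_op)
//         -> (!transform.any_op, !transform.any_op, !transform.any_op)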
      extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
  if (any_of(packPaddings, [](int64_t packPadding) {
        return packPadding != 0 && packPadding != 1;
    return emitOpError()
           << "expects pack_paddings to contain booleans (0/1), found "
           << getPackPaddings();

      extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
  if (any_of(paddingDimensions,
             [](int64_t paddingDimension) { return paddingDimension < 0; })) {
    return emitOpError() << "expects padding_dimensions to contain positive "
                         << getPaddingDimensions();
  if (getPadToMultipleOf().has_value()) {
    if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
      return emitOpError() << "expects as many multiples as padding_dimensions";

  ArrayAttr transposes = getTransposePaddings();
    auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
    if (!std::is_permutation(sequence.begin(), sequence.end(),
                             transpose.begin(), transpose.end())) {
      return emitOpError()
             << "expects transpose_paddings to be a permutation, found "
  if (getCopyBackOp() !=
          bufferization::MaterializeInDestinationOp::getOperationName() &&
      getCopyBackOp() != linalg::CopyOp::getOperationName() &&
      getCopyBackOp() != kCopyOpNone)
    return emitOpError() << "invalid copy_back_op";
  auto targetOps = state.getPayloadOps(getTarget());
  auto loopOps = state.getPayloadOps(getLoop());
  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
           << "requires exactly one target and one loop handle (got "
           << llvm::range_size(targetOps) << " and "
           << llvm::range_size(loopOps) << ")";

  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
  if (!padOp || !loopOp)

  if (result->clonedLoopIvs.empty()) {
    transformResults.set(cast<OpResult>(getPackingLoop()),
                         {result->hoistedPadOp.getOperation()});
  auto outerPackedLoop =
  transformResults.set(cast<OpResult>(getPackingLoop()),
                       {outerPackedLoop.getOperation()});

  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
    return emitOpError() << "expects transpose to be a permutation, found "

void transform::HoistPadBuildPackingLoopNestOp::getEffects(

                                   tensor::PadOp target,
  tensor::PadOp hoistedPadOp;
                                       hoistedPadOp, transposeOps);
  return emitDefaultSilenceableFailure(target);

  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
    return emitOpError() << "expects transpose to be a permutation, found "
  if (!getOperandsToPromote().empty())
        extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
  if (getUseFullTilesByDefault())
        getUseFullTilesByDefault());
  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
  if (!getUseFullTileBuffers().empty())
        llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
  if (getAlignment().has_value())
    promotionOptions = promotionOptions.setAlignment(*getAlignment());
  if (getMemorySpace().has_value())
    promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());

  if (getMapping().has_value()) {
    auto mapping = *getMapping();
    if (mapping.size() > 1)
      return emitDefaultDefiniteFailure(target);

    auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
    if (addressSpace.getAddressSpace() ==
        mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
    } else if (addressSpace.getAddressSpace() ==
               mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
      return emitDefaultDefiniteFailure(target);

    return emitDefaultDefiniteFailure(target);
    return emitDefaultDefiniteFailure(target);
  auto payload = state.getPayloadOps(getTarget());

    if (target->getNumOperands() > 0)
        target->getNumRegions() > 0)
             << "expected target that is isolated from above";

  Operation *pattern = &getBodyRegion().front().front();
    if (getOperation()->isAncestor(target))
    replacements.push_back(replacement);
  transformResults.set(cast<OpResult>(getReplacement()), replacements);

void transform::ReplaceOp::getEffects(

  if (!getBodyRegion().hasOneBlock())
    return emitOpError() << "expected one block";
  if (std::distance(getBodyRegion().front().begin(),
                    getBodyRegion().front().end()) != 1)
    return emitOpError() << "expected one operation in block";
  Operation *replacement = &getBodyRegion().front().front();
           << "expected replacement without operands";
           << "expect op that is isolated from above";
      target.createFlatListOfOperandDims(b, loc);
  AffineMap map = target.getShapesToLoopsMap();

      rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
  if (failed(maybeTilingResult))
    return emitDefaultDefiniteFailure(target);

  if (target->getNumResults())
    rewriter.replaceOp(target, maybeTilingResult->replacements);

  results.reserve(maybeTilingResult->tiledOps.size());
  for (Operation *tiled : maybeTilingResult->tiledOps)

transform::RewriteInDestinationPassingStyleOp::applyToOne(
          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
              [&rewriter](auto op) {
  return emitDefaultSilenceableFailure(target);
      llvm::to_vector(state.getPayloadOps(getTarget()));
  splitPoints.reserve(payload.size());
  if (getDynamicSplitPoint()) {
    if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
      splitPoints = llvm::to_vector(llvm::map_range(
          state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
              diag = emitSilenceableError()
                     << "expected dynamic split point handle to point to a "
                        "single-result index-typed op";
              diag.attachNote(op->getLoc()) << "dynamic split point";
      splitPoints = llvm::to_vector(
          llvm::map_range(state.getParams(getDynamicSplitPoint()),
    if (diag.isSilenceableFailure())

    if (splitPoints.size() != payload.size()) {
             << "expected the dynamic split point handle to point to as "
             << splitPoints.size() << ") as the target handle ("
             << payload.size() << ")";
    splitPoints.resize(payload.size(),

  for (const auto &pair : llvm::zip(payload, splitPoints)) {
    auto linalgOp = dyn_cast<LinalgOp>(target);
      auto diag = emitSilenceableError() << "only applies to structured ops";
      diag.attachNote(target->getLoc()) << "target op";

    if (getDimension() >= linalgOp.getNumLoops()) {
      auto diag = emitSilenceableError() << "dimension " << getDimension()
                  << " does not exist in target op";
      diag.attachNote(target->getLoc()) << "target op";

    std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
        rewriter, cast<TilingInterface>(linalgOp.getOperation()),
        getDimension(), std::get<1>(pair));

    if (!first.back() && !second.back()) {
      diag.attachNote(target->getLoc()) << "target op";

    if (!second.back()) {
      noSecondPart = target;

  if (second.size() != first.size() && !second.empty()) {
    auto diag = emitSilenceableError()
                << "splitting does not produce the second part for a subset "
    diag.attachNote() << "expected splitting to produce the second part of all "
                         "or none of the targets";
        << "first target with no second part";

  results.set(cast<OpResult>(getFirst()), first);
  results.set(cast<OpResult>(getSecond()), second);
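// The split logic above divides each structured payload op along
// getDimension() at a static or dynamic split point, producing a "first" and a
// possibly empty "second" handle; producing the second part for only a subset
// of the targets is rejected as a silenceable failure.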
void SplitOp::getEffects(
  if (getDynamicSplitPoint())

  IntegerAttr staticSplitPoint;
  if (!dynamicPointParseResult.has_value()) {
    int64_t staticSplitPointValue;

  if (dynamicPointParseResult.has_value()) {
    Type splitPointType;

      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
  result.addTypes({targetType, targetType});

  printer << " " << getTarget() << " after ";
  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
  if (staticSplitSize != ShapedType::kDynamic)
    printer << staticSplitSize;
    printer << getDynamicSplitPoint();
                                   {getStaticSplitPointAttrName()});
  printer << " : " << getTarget().getType();
  if (staticSplitSize == ShapedType::kDynamic)
    printer << ", " << getDynamicSplitPoint().getType();

  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
      (getDynamicSplitPoint() == nullptr)) {
    return emitOpError() << "expects either a dynamic or a static split "
                            "point to be provided";
void transform::SplitReductionOp::build(
    int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
    bool useScalingAlgorithm, bool useAlloc) {
      SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
  if (innerParallel) {
    result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
  if (useScalingAlgorithm) {
        SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
  result.addTypes({resultType, resultType, resultType, resultType});

                                    unsigned(getInsertSplitDimension()),
                                    bool(getInnerParallel())};
      (getUseScalingAlgorithm())
    return emitDefaultDefiniteFailure(target);
  results.push_back(splitResult->initOrAlloc);
  results.push_back(splitResult->splitLinalgOp);
  results.push_back(splitResult->resultCombiningLinalgOp);
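// Illustrative (hedged) use of split_reduction; the attribute names and the
// four-result form are assumptions based on the build() above:
//
//   %init, %fill, %split, %combine =
//       transform.structured.split_reduction %reduce
//       {split_factor = 4, insert_split_dimension = 1}
//       : (!transform.any_op)
//         -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)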
void transform::TileReductionUsingForOp::build(
  build(builder, result,
        staticTileSizesAttr);

      rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
    return emitDefaultSilenceableFailure(target);
  results.push_back(result->parallelTiledOp);
  results.push_back(result->loops.front());

void transform::TileReductionUsingForallOp::build(
                                                  ArrayAttr mapping) {
  build(builder, result,
        staticNumThreadsAttr,
        staticTileSizesAttr,

      rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
      numThreads, tileSizes, getMapping());
    auto diag = emitSilenceableError() << "could not tile reduction";
    diag.attachNote(target.getLoc()) << "target operation";
  results.push_back(result->parallelTiledOp);
void transform::TileUsingForOp::build(
  return build(builder, result, loopTypes,
               interchange, scalableSizes);

void transform::TileUsingForOp::build(
  build(builder, result, target,
        interchange, scalableSizes);

void transform::TileUsingForOp::build(
  build(builder, result, loopTypes, target, mixedTileSizes, interchange,

void transform::TileUsingForOp::build(
  unsigned numExpectedLoops =
      staticTileSizes.size() - llvm::count(staticTileSizes, 0);
  resultTypes.reserve(numExpectedLoops);
  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
         "expected one loop type or as many as loops");
  if (loopTypes.size() == 1)
    resultTypes.append(numExpectedLoops, loopTypes[0]);
    llvm::append_range(resultTypes, loopTypes);
  if (scalableSizes.has_value())
    expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
  build(builder, result, target.getType(),
        staticTileSizesAttr,
        expandedScalableSizes);

    return emitOpError("expected same number of sizes (")
           << getScalableSizes().size() << ")";
      llvm::to_vector(state.getPayloadOps(getTarget()));
    if (isa<ParamType>(transformValue.getType())) {
      dynamicSizeProducers.push_back({});
      paramSizes.push_back(
          llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
            return cast<IntegerAttr>(attr).getValue().getSExtValue();
      if (paramSizes.back().size() != targets.size()) {
            emitSilenceableError()
            << "expected as many parameter values ("
            << dynamicSizeProducers.back().size() << ") as target ops ("
            << targets.size() << ")";
        diag.attachNote(transformValue.getLoc()) << "for this parameter";

    paramSizes.push_back({});
    dynamicSizeProducers.push_back(
        llvm::to_vector(state.getPayloadOps(transformValue)));
    if (dynamicSizeProducers.back().size() != targets.size()) {
          emitSilenceableError()
          << "expected as many dynamic size-producing operations ("
          << dynamicSizeProducers.back().size() << ") as target ops ("
          << targets.size() << ")";
      diag.attachNote(transformValue.getLoc()) << "for this handle";

    for (Operation *op : dynamicSizeProducers.back()) {
          emitSilenceableError() << "expected sizes to be produced by ops "
                                    "with a single index-type result";
      diag.attachNote(op->getLoc()) << "size producer op";
      diag.attachNote(transformValue.getLoc()) << "for this handle";

  loops.resize(getLoops().size());
  auto scalableSizes = getScalableSizes();
    auto tilingInterface = dyn_cast<TilingInterface>(op);
    if (!tilingInterface) {
          emitSilenceableError()
          << "only ops implementing TilingInterface are supported";
      diag.attachNote(op->getLoc()) << "target op";
    if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
          emitSilenceableError()
          << "too many tiles provided, expected at most "
          << tilingInterface.getLoopIteratorTypes().size() << " found "
          << tileSizes.size();
      diag.attachNote(op->getLoc()) << "target op";

    if (!tileSizes.empty()) {
        sizes.reserve(tileSizes.size());
        unsigned dynamicIdx = 0;
          if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
            if (scalableSizes[ofrIdx]) {
              auto val = b.create<arith::ConstantIndexOp>(
                  getLoc(), attr.cast<IntegerAttr>().getInt());
            sizes.push_back(attr);

          assert((dynamicSizes.empty() ^ params.empty()) &&
                 "expected either dynamic sizes or parameters");
          if (!params.empty()) {
            sizes.push_back(dynamicSizes[index]->getResult(0));

    if (failed(maybeTilingResult))
    rewriter.replaceOp(op, maybeTilingResult->replacements);

    tiled.append(maybeTilingResult->tiledOps);
      loops[en2.index()].push_back(en2.value());

  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
    transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
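// Illustrative (hedged) tiling invocation; the loop result types are
// assumptions and depend on the requested loop kind:
//
//   %tiled, %loop0, %loop1 = transform.structured.tile_using_for %matmul [8, 16]
//       : (!transform.any_op)
//         -> (!transform.any_op, !transform.any_op, !transform.any_op)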
  results.reserve(tileSizes.size());
  unsigned dynamicPos = 0;
  for (int64_t size : tileSizes) {
    if (size == ShapedType::kDynamic) {
      results.push_back(dynamic[dynamicPos++]);

  if (!interchangeVals.empty()) {
    p << " {interchange = [";
    llvm::interleaveComma(interchangeVals, p,
                          [&](int64_t integer) { p << integer; });

  FunctionType functionalType;
  llvm::SMLoc operandLoc;

  size_t numExpectedLoops =
      staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
  if (functionalType.getNumResults() != numExpectedLoops + 1) {
           << "expected " << (numExpectedLoops + 1) << " result type(s)";
  if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
           << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
  if (parser.resolveOperand(target, functionalType.getInputs().front(),
                             functionalType.getInputs().drop_front(),
  result.addTypes(functionalType.getResults());

  p << ' ' << getTarget();
                          {}, getScalableSizesAttr(),

void transform::TileUsingForOp::getEffects(
void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         ArrayAttr mapping) {
  return build(builder, result,

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         ArrayAttr mapping) {
  build(builder, result,
        TypeRange{operationType, operationType},
        staticTileSizesAttr,

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         ArrayAttr mapping) {
  return build(builder, result, target,

void transform::TileUsingForallOp::build(OpBuilder &builder,
                                         ArrayAttr mapping) {
  build(builder, result,
        TypeRange{operationType, operationType},
        staticNumThreadsAttr,

                             TransformOpInterface transformOp, Operation *target,
  auto tileableOp = dyn_cast<TilingInterface>(target);
        transformOp.emitSilenceableError()
        << "only TilingInterface ops are supported";
    diag.attachNote(target->getLoc()) << "target op";

  if (!mixedNumThreads.empty()) {
        rewriter, tileableOp, mixedTileSizes, mapping);
  if (failed(maybeTilingResult))
    return transformOp.emitDefaultSilenceableFailure(tileableOp);
  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());

  tilingResult = *maybeTilingResult;

  auto transformOp = cast<TransformOpInterface>(getOperation());
      getPackedNumThreads()
                state, transformOp, mixedNumThreads, getPackedNumThreads())
                state, transformOp, mixedNumThreads, getMixedNumThreads());
  status = getPackedTileSizes()
               state, transformOp, mixedTileSizes, getPackedTileSizes())
               state, transformOp, mixedTileSizes, getMixedTileSizes());

  for (Operation *target : state.getPayloadOps(getTarget())) {
        rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
        getMapping(), tilingResult);
    if (!diag.succeeded())
    tileOps.push_back(tilingResult.tileOp);
    tiledOps.push_back(tilingResult.tiledOp);

  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
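// Illustrative (hedged) forall tiling; exactly one of num_threads or
// tile_sizes (packed or unpacked) must be present, as the verifier below
// enforces:
//
//   %forall, %tiled = transform.structured.tile_using_forall %matmul
//       num_threads [4, 8] (mapping = [#gpu.block<y>, #gpu.block<x>])
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)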
void transform::TileUsingForallOp::getEffects(

  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);

  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
                       static_cast<int>(getPackedNumThreads() != Value());
  if (numThreadsSpec > 1)
        "num_threads and packed_num_threads are mutually exclusive");
  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
                      static_cast<int>(getPackedTileSizes() != Value());
  if (tileSizesSpec > 1)
        "tile_sizes and packed_tile_sizes are mutually exclusive");
  if (numThreadsSpec == 0 && tileSizesSpec == 0)
    return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
                       "must be specified");
void transform::VectorizeChildrenAndApplyPatternsOp::build(
    bool vectorizePadding, bool vectorizeExtract) {
  if (vectorizePadding) {
        VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
  if (vectorizeExtract) {
        VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(

  explicit VectorizationPattern(MLIRContext *context,
                                bool vectorizeExtract = false)
        vectorizeNDExtract(vectorizeExtract) {}
    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
    return vectorize(rewriter, linalgOp, {},
                     {}, vectorizeNDExtract);
  bool vectorizeNDExtract = false;

transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
    auto diag = this->emitOpError("requires isolated-from-above targets");
    diag.attachNote(target->getLoc()) << "non-isolated target";

  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());

  if (!getDisableTransferPermutationMapLoweringPatterns())
  if (!getDisableMultiReductionToContractPatterns())
  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
  if (getVectorizePadding())
    return emitDefaultDefiniteFailure(target);
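// applyToOne above requires an isolated-from-above target, vectorizes every
// LinalgOp inside it via VectorizationPattern, and then runs a curated set of
// vector patterns (transfer permutation lowering, multi-reduction to contract,
// transfer_read/transfer_write canonicalization, and optionally pad
// vectorization) before reporting success or failure.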
  auto targets = state.getPayloadOps(getTarget());
  if (std::empty(targets))
      vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());

    auto szPayloads = state.getPayloadOps(sz.get<Value>());
    if (!llvm::hasSingleElement(szPayloads)) {
      auto diag = this->emitOpError(
          "requires vector size handle that is mapped to 1 payload op");
          << "mapped to " << llvm::range_size(szPayloads) << " payload ops";

    Operation *szPayloadOp = *szPayloads.begin();
      auto diag = this->emitOpError(
          "requires vector size payload op with 1 index result");
      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";

      auto diag = this->emitOpError("requires constant vector size");
      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";

    vectorSizes.push_back(attr.getInt());

    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
          << "Unsupported Op, cannot vectorize";
                        getVectorizeNdExtract().has_value()
                            ? getVectorizeNdExtract().value()
          << "Attempted to vectorize, but failed";

void transform::VectorizeOp::getEffects(

  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);

  if (getStaticVectorSizes().size() != getScalableSizes().size())
    return emitOpError("expected same number of vector sizes (")
           << getStaticVectorSizes().size() << ") and scalable sizes ("
           << getScalableSizes().size() << ")";
transform::HoistRedundantVectorTransfersOp::applyToOne(

  auto maybeTransformed =
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
          .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
          .Case([&](linalg::Conv2DNchwFchwOp op) {
  if (failed(maybeTransformed))
    return emitDefaultSilenceableFailure(target);
  results.push_back(maybeTransformed->first);
  results.push_back(maybeTransformed->second);

  auto maybeTransformed =
          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
          .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
  if (failed(maybeTransformed))
    return emitDefaultSilenceableFailure(target);
template <typename OpTy>
  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
                                tensor::ParallelInsertSliceOp>() &&
  if (auto copySource =
          target.getSource().template getDefiningOp<linalg::CopyOp>()) {

  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
        target->template getParentOfType<scf::InParallelOp>());

  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
      target.getLoc(), target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());
  Value copied = rewriter
                     .create<linalg::CopyOp>(target.getLoc(),
                                             target.getSource(), extracted)
      target, copied, target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());
  results.push_back(copied.getDefiningOp());

  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
    return doit(rewriter, target, results, state);
  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
    return doit(rewriter, target, results, state);

      emitSilenceableError()
      << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
  diag.attachNote(targetOp->getLoc()) << "target op";
  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
        emitSilenceableError()
        << "only linalg.copy and tensor.pad target ops are supported";
    diag.attachNote(target->getLoc()) << "target op";
  assert(target->getNumResults() == 1 && "expected single result");
  if (!resultShapedType.hasStaticShape()) {
        emitSilenceableError()
        << "only statically sized ops of rank <= 3 are supported";
    diag.attachNote(target->getLoc()) << "target op";

  int64_t desiredBitAlignment = getDesiredBitAlignment();
  int64_t eltBitwidth =
      resultShapedType.getElementType().getIntOrFloatBitWidth();
  if (desiredBitAlignment % eltBitwidth != 0) {
    desiredBitAlignment = eltBitwidth;

      getTotalNumThreads(),
      desiredBitAlignment,
      resultShapedType.getShape(),
      resultShapedType.getElementType().getIntOrFloatBitWidth());
  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
        emitSilenceableError()
        << "too few threads to map copy op to threads on the most minor "
           "dimension, given alignment and vector size constraints, try "
           "smaller tile size of mapping to more threads";
    diag.attachNote(target->getLoc()) << "target op";

  if (!diag.succeeded())
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static FailureOr< ForallTilingResult > tileToForallOpImpl(RewriterBase &b, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayRef< OpFoldResult >> nominalTileSizes, std::optional< ArrayAttr > mapping, bool omitTileOffsetBoundsCheck)
Rewrite a TilingInterface op to a tiled scf.forall.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
@ Square
Square brackets surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class provides support for representing a failure result, or a valid value of type T.
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
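A minimal sketch of cloning through an IRMapping; `builder`, `op`, `oldOperand`, and `newOperand` are placeholders:
IRMapping mapping;
mapping.map(oldOperand, newOperand);   // remap one operand during cloning
Operation *copy = builder.clone(*op, mapping);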
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
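A minimal sketch of scoped insertion-point management with OpBuilder::InsertionGuard; `builder` and `someOp` are placeholders:
{
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfter(someOp);
  // Ops created here are inserted right after `someOp`.
}
// The guard restores the original insertion point on destruction.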
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
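A minimal sketch of the walk API; `root` is assumed to be an Operation*, and the callback is filtered to the op or interface type it takes:
int64_t numLinalgOps = 0;
root->walk([&](linalg::LinalgOp) { ++numLinalgOps; });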
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor)
Find uses of from and replace them with to if the functor returns true.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
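A minimal sketch of this rewriter API inside a rewrite pattern; the pattern itself (eliding an identity tensor.cast) is a made-up illustration, not a pattern defined in this file:
struct ElideIdentityCast : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (castOp.getSource().getType() != castOp.getType())
      return rewriter.notifyMatchFailure(castOp, "not an identity cast");
    // Replace all uses of the cast with its source and erase the op.
    rewriter.replaceOp(castOp, castOp.getSource());
    return success();
  }
};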
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
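A minimal sketch; `rewriter` and `namedOp` are placeholders:
FailureOr<linalg::GenericOp> generic =
    linalg::generalizeNamedOp(rewriter, namedOp);
if (failed(generic))
  return failure();
// `*generic` is the linalg.generic that replaced `namedOp`.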
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
FailureOr< ForallTilingResult > tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, std::optional< ArrayAttr > mapping)
Same as tileToForallOp, but calculates the number of threads required using the given tileSizes.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite tensor.unpack as empty + transpose + reshape + extract_slice.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< ForallTilingResult > tileToForallOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayAttr > mapping)
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< GenericOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
void hoistRedundantVectorTransfers(func::FuncOp func)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and indexing_maps dimensions and adapt the index accesses of op.
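A minimal sketch with an arbitrary permutation; `rewriter` and `genericOp` are placeholders, and {1, 0} simply swaps the two outer loops:
SmallVector<unsigned> perm = {1, 0};
FailureOr<linalg::GenericOp> swapped =
    linalg::interchangeGenericOp(rewriter, genericOp, perm);
if (failed(swapped))
  return failure();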
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
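A minimal sketch, assuming `linalgOp` is a matmul-like op with three iteration dimensions; the 8/16/32 block sizes are arbitrary and `rewriter` is a placeholder:
SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
                                         rewriter.getIndexAttr(16),
                                         rewriter.getIndexAttr(32)};
FailureOr<linalg::PackResult> packed =
    linalg::pack(rewriter, linalgOp, packedSizes);
if (failed(packed))
  return failure();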
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false)
Emit a suitable vector form for an operation.
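A minimal sketch of the common statically-shaped case, where no explicit vector sizes are needed; `rewriter` and `op` are placeholders:
if (failed(linalg::vectorize(rewriter, op)))
  return failure();   // e.g. dynamic shapes without explicit vector sizes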
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SCFTilingResult > tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
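A minimal sketch with fixed tile sizes (8 and 32 are arbitrary); `rewriter` is a RewriterBase and `op` an op implementing TilingInterface:
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(
    {rewriter.getIndexAttr(8), rewriter.getIndexAttr(32)});
FailureOr<scf::SCFTilingResult> tiled =
    scf::tileUsingSCFForOp(rewriter, op, tilingOptions);
if (failed(tiled))
  return failure();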
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert an int64_t to an integer attribute of index type and return it as an OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
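A minimal sketch of these helpers; `ctx` is an MLIRContext*:
OpFoldResult four = getAsIndexOpFoldResult(ctx, 4);
std::optional<int64_t> maybeCst = getConstantIntValue(four);  // holds 4
// For an OpFoldResult wrapping a non-constant Value this is std::nullopt.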
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parse a single MLIR attribute in the given MLIR context; returns a null Attribute if parsing fails.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
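A minimal sketch combining this driver with one of the populate functions above; `region` and `ctx` are placeholders:
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(region, std::move(patterns))))
  return failure();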
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), ArrayRef< bool > scalables={}, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
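A minimal sketch of round-tripping between mixed OpFoldResults and the static/dynamic split used by ops carrying static size attributes; `mixedSizes` and `b` are placeholders:
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticSizes;
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
SmallVector<OpFoldResult> roundTripped =
    getMixedValues(staticSizes, dynamicSizes, b);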
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
static LogicalResult failure(bool isFailure=true)
If isFailure is true a failure result is generated, otherwise a 'success' result is generated.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Rewrite a TilingInterface op to a tiled scf.forall, applying tiling by numThreads.
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.