43#include "llvm/ADT/STLExtras.h"
44#include "llvm/ADT/ScopeExit.h"
45#include "llvm/ADT/SmallPtrSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
55#define DEBUG_TYPE "linalg-transforms"
62template <
typename PatternTy,
typename... Args>
65 using OpTy =
typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
72 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
80 return cast<LinalgOp>(
result->getOperation());
90 if (
auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 <<
"requires exactly one parameter associated";
103 result.push_back(params[0]);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 <<
"handle must be mapped to exactly one payload op";
113 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
120 transformOp.emitSilenceableError()
121 <<
"payload op must have exactly 1 index result";
141 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
143 for (
auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 <<
"expected the parameter to be associated with an integer "
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 <<
"payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 <<
"has " << op->getNumResults() <<
" results";
162 result.push_back(op->getResult(0));
176 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
179 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
181 if (params.size() != 1)
182 return transformOp.emitSilenceableError() <<
"expected a single param";
184 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
188 Value handle = cast<Value>(paramOrHandle);
189 if (!isa<TransformHandleTypeInterface>(handle.getType()))
190 return transformOp.emitSilenceableError() <<
"unexpected value handle";
192 if (!llvm::hasSingleElement(payload))
193 return transformOp.emitSilenceableError()
194 <<
"requires param or handle that is mapped to 1 payload op";
196 Operation *paramOrHandlePayloadOp = *payload.begin();
199 return transformOp.emitSilenceableError()
200 <<
"requires param or handle to be result of op with 1 index "
206 return transformOp.emitSilenceableError()
207 <<
"requires param or handle to be the result of a constant like "
210 reified.push_back(attr.getInt());
219void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
224void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
229void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
234void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
240void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
243 options.rankReductionStrategy =
248void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
253void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
258void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
263void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
268void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
282 SmallVector<Operation *> getNewOps()
const {
283 return SmallVector<Operation *>(newOps.begin(), newOps.end());
287 void notifyOperationInserted(Operation *op,
288 OpBuilder::InsertPoint previous)
override {
289 ForwardingListener::notifyOperationInserted(op, previous);
291 if (previous.
isSet())
295 assert(
inserted.second &&
"expected newly created op");
298 void notifyOperationErased(Operation *op)
override {
299 ForwardingListener::notifyOperationErased(op);
300 op->
walk([&](Operation *op) { newOps.erase(op); });
313 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
314 NewOpsListener newOpsListener(previousListener);
318 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
319 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
320 MaterializeInDestination;
321 }
else if (getMemcpyOp() ==
"memref.copy") {
324 }
else if (getMemcpyOp() ==
"linalg.copy") {
328 llvm_unreachable(
"invalid memcpy op");
330 if (getAllocOp() ==
"memref.alloc") {
333 }
else if (getAllocOp() ==
"memref.alloca") {
337 llvm_unreachable(
"invalid alloc op");
339 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
340 options.emitDealloc = getEmitDealloc();
344 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
351 <<
"failed to bufferize operation";
352 diag.attachNote(op->
getLoc()) <<
"target payload op";
355 allocatedBuffers.push_back(buffer);
359 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
360 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
364void transform::BufferizeToAllocationOp::getEffects(
366 if (getBufferizeDestinationOnly()) {
377LogicalResult transform::BufferizeToAllocationOp::verify() {
378 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
379 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
381 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
394 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.
getOwner());
401 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
409 if (!isa<TensorType, FloatType, IntegerType>(value.
getType()))
411 return llvm::any_of(value.
getUses(),
421 auto type = dyn_cast<RankedTensorType>(
tensor.getType());
423 return emitSilenceableError() <<
"non-tensor type: " <<
tensor;
437 for (
auto [pos, dim] : llvm::enumerate(type.getShape())) {
438 if (!ShapedType::isDynamic(dim))
443 tensor::DimOp::create(rewriter,
tensor.getLoc(),
tensor, cst);
444 preservedOps.insert(dimOp);
445 dynamicDims.push_back(dimOp);
447 auto allocation = bufferization::AllocTensorOp::create(
448 rewriter,
tensor.getLoc(), type, dynamicDims);
450 if (getMemorySpaceAttr())
451 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
452 Value allocated = allocation;
456 if (needsMaterialization) {
457 auto copy = bufferization::MaterializeInDestinationOp::create(
459 preservedOps.insert(
copy);
460 promoted.push_back(
copy.getResult());
462 promoted.push_back(allocated);
466 results.
setValues(cast<OpResult>(getPromoted()), promoted);
470void transform::PromoteTensorOp::getEffects(
486#define DOWNSCALE(trans) \
488 FailureOr<LinalgOp> res = tryApply<trans>(target); \
489 if (succeeded(res)) { \
490 results.push_back(*res); \
491 return DiagnosedSilenceableFailure::success(); \
/// DOWNSCALE_CALL names the concrete downscaling pattern instantiation for a
/// (source op, target op) pair; DOWNSCALE_NORMAL feeds it to the DOWNSCALE
/// macro, which tries to apply the pattern to `target` and, on success,
/// records the rewritten op and returns success (see DOWNSCALE above).
#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
509#undef DOWNSCALE_NORMAL
512 return emitDefaultSilenceableFailure(
target);
526 auto decomposableOp = dyn_cast<AggregatedOpInterface>(
target);
527 if (!decomposableOp) {
529 "payload is not a decomposable op"));
530 return emitDefaultSilenceableFailure(
target);
533 FailureOr<SmallVector<Value>> maybeNewResults =
534 decomposableOp.decomposeOperation(rewriter);
535 if (
failed(maybeNewResults))
536 return emitDefaultSilenceableFailure(
target);
538 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
539 for (
Value val : *maybeNewResults) {
540 Operation *definition = val.getDefiningOp();
551void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
558transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
562 options.allowReturnAllocsFromLoops =
true;
568 <<
"failed to analyze op";
570 rewriter,
target, state)))
572 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
585 bool applyCleanup,
bool useForall) {
587 builder,
result, loopTypes,
593 applyCleanup, useForall);
599 bool applyCleanup,
bool useForall) {
607 applyCleanup, useForall);
614 bool applyCleanup,
bool useForall) {
618 build(builder,
result, loopTypes,
target, mixedTileSizes,
619 mixedTileInterchange, applyCleanup, useForall);
626 bool applyCleanup,
bool useForall) {
633 staticTileInterchange);
638 auto staticTileInterchangeAttr =
640 unsigned numExpectedLoops =
641 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
643 resultTypes.reserve(numExpectedLoops);
644 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
645 "expected one loop type or as many as loops");
646 if (loopTypes.size() == 1)
647 resultTypes.append(numExpectedLoops, loopTypes[0]);
649 llvm::append_range(resultTypes, loopTypes);
654 dynamicTileInterchange,
656 staticTileInterchangeAttr,
663template <
typename Range>
667 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
673 auto tilingInterfaceOp = dyn_cast<TilingInterface>(
target);
674 if (!tilingInterfaceOp)
675 return transformOp->
emitError(
"only TilingInterface ops are supported");
678 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
679 applyFn(tilingInterfaceOp);
680 if (failed(tiledResults))
685 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
686 for (
Operation *toReplace : opsToReplace) {
687 for (
OpResult res : toReplace->getResults())
688 if (
auto replacement = tiledResults->replacements.lookup(res))
690 if (toReplace->use_empty()) {
696 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
697 assert(tiledResults->loops.size() == numLoops &&
698 "Mismatched number of loops, tile and fuse transform should have "
700 for (
unsigned int i = 0; i < numLoops; ++i)
701 loopOps[i].
push_back(tiledResults->loops[i]);
704 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
705 for (
unsigned int i = 0; i < numLoops; ++i)
706 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
715 auto transformOp = cast<TransformOpInterface>(getOperation());
719 state, transformOp, getMixedTileSizes(), tileSizes);
724 state, transformOp, getMixedTileInterchange(), tileInterchange);
728 scf::SCFTilingOptions tilingOptions;
729 tilingOptions.interchangeVector = tileInterchange;
730 bool useForall = getUseForall();
731 tilingOptions.setLoopType(useForall
732 ? scf::SCFTilingOptions::LoopType::ForallOp
733 : scf::SCFTilingOptions::LoopType::ForOp);
736 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
737 scf::SCFTileAndFuseOptions tileAndFuseOptions;
738 tileAndFuseOptions.tilingOptions = tilingOptions;
740 if (getApplyCleanup()) {
743 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
746 tileAndFuseOptions.cleanupPatterns = std::move(
patterns);
750 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
752 rewriter, getOperation(), state.
getPayloadOps(getTarget()), numLoops,
754 [&](TilingInterface tilingInterfaceOp)
755 -> FailureOr<scf::SCFTileAndFuseResult> {
756 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
763LogicalResult transform::FuseOp::verify() {
764 auto iterspace_rank = getStaticTileSizes().size();
766 if (permutation.size() > iterspace_rank)
768 <<
"interchange length exceeds iteration space dimensions ("
769 << iterspace_rank <<
"), found " << getTileInterchange();
771 for (
int64_t v : permutation) {
772 if (!ShapedType::isDynamic(v)) {
773 if (v < 0 || v >=
static_cast<int64_t>(iterspace_rank))
774 return emitOpError() <<
"expects interchange values to be in range [0, "
775 << iterspace_rank <<
"), found: " << v;
777 return emitOpError() <<
"found duplicate interchange value: " << v;
783 size_t numExpectedLoops =
784 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
785 if (numExpectedLoops != getNumResults() - 1)
786 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
796 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
800void transform::FuseOp::getEffects(
813void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
816 Value containingOp) {
817 result.addOperands({producerOp, containingOp});
818 auto resultType = transform::AnyOpType::get(builder.
getContext());
819 result.addTypes({resultType, resultType});
835 (domInfo.
dominates(containingOp, user))) {
836 dominatedUsers.insert(user);
839 if (dominatedUsers.empty())
843 auto forallOp = cast<scf::ForallOp>(containingOp);
849 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
854 newOuts.push_back(outputs[resultNumber]);
857 auto newforallOp = scf::ForallOp::create(
858 rewriter, loc, forallOp.getMixedLowerBound(),
859 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
860 forallOp.getMapping());
862 newforallOp.getRegion().takeBody(forallOp.getRegion());
867 newforallOp.getBody()->addArgument(newOuts.back().getType(),
868 newOuts.back().getLoc());
869 auto bbArgs = newforallOp.getBody()->getArguments();
872 Operation *op = use.getOwner();
873 return newforallOp->isProperAncestor(op);
877 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
879 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
880 Operation *firstYieldOp = yieldingOps.front();
883 Value dst = newforallOp.getRegionIterArgs().back();
885 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->
getLoc(), src,
886 dst, offsets, sizes, strides);
888 for (
auto result : llvm::enumerate(forallOp.getResults())) {
890 newforallOp->getResult(
result.index()));
893 newforallOp->getResults().back(),
895 Operation *user = use.getOwner();
896 return dominatedUsers.contains(user);
910 destWorklist.push_back(dst);
912 while (!destWorklist.empty()) {
913 Value currentDst = destWorklist.pop_back_val();
917 if (src == currentDst)
922 auto bbArg = dyn_cast<BlockArgument>(currentDst);
926 Block *parentBlock = bbArg.getOwner();
927 assert(parentBlock &&
"unlinked block argument");
930 assert(parentOp &&
"expected block argument with parent operation");
933 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
937 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
939 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
940 Value loopBlockArgument =
942 destWorklist.push_back(loopBlockArgument);
955static std::tuple<SmallVector<Operation *>,
Operation *>
958 LDBG() <<
"Try to fuse a direct extract use";
959 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
960 if (!tileableProducer) {
962 <<
"producer is not a TileableInterface: " << *producerOp;
969 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
970 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
971 return sliceOp && containingOp->isProperAncestor(sliceOp);
975 if (it == tileableProducer->getUsers().end()) {
976 diag.attachNote(tileableProducer->getLoc())
977 <<
"could not find fusion opportunity for: " << *tileableProducer;
980 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
993 if (LoopLikeOpInterface containerLoop =
994 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
1000 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(
clone);
1004 for (
OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1005 Value producerOperand =
1006 clone->getOperand(initOperandPtr.getOperandNumber());
1008 containerLoop.getRegionIterArgs()) {
1009 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1010 Value consumerOperand =
1014 initOperandPtr.set(containerIterArg);
1020 tileableProducer = dyn_cast<TilingInterface>(
clone);
1025 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1026 LDBG() <<
"resultNumber: " << resultNumber;
1031 FailureOr<TilingResult> tileAndFuseResult =
1032 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1035 if (failed(tileAndFuseResult)) {
1036 diag.attachNote(tileableProducer->getLoc())
1037 <<
"failed to tile producer op: " << *tileableProducer;
1042 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
1043 LDBG() <<
"tiledProducer: " << *tiledOp;
1048 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1049 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1050 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1051 if (failed(maybeRankReduced)) {
1053 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
1054 << tileAndFuseResult->tiledValues[0]
1055 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
1058 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1062 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
1063 resultNumber, offsets, sizes);
1066 if (isa<LoopLikeOpInterface>(containingOp))
1067 rewriter.
eraseOp(tileableProducer);
1069 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1082 LDBG() <<
"Try to fuse an extract use through block argument";
1084 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1085 if (!tileableProducer) {
1087 <<
"producer is not a TileableInterface: " << *producerOp;
1092 scf::ForallOp forallOp;
1093 auto itProducerUses =
1094 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
1095 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1099 if (!forallOp || forallOp != containingOp) {
1100 diag.attachNote(tileableProducer->getLoc())
1101 <<
"could not find a use by the containing op: " << *tileableProducer;
1116 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1117 return sliceOp && containingOp->isProperAncestor(sliceOp);
1121 if (itBBArgUsers == bbArg.
getUsers().end()) {
1123 <<
"could not find fusion opportunity for bbArg: " << bbArg;
1126 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1134 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
1135 LDBG() <<
"resultNumber: " << resultNumber;
1140 rewriter, tileableProducer->getLoc(), tileableProducer,
1141 destinationTensors))) {
1142 diag.attachNote(tileableProducer->getLoc())
1143 <<
"failed to get destination tensors for: " << *tileableProducer;
1148 bvm.
map(destinationTensors[resultNumber], bbArg);
1149 auto tileableProducerClone =
1150 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
1152 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
1155 FailureOr<TilingResult> tileAndFuseResult =
1156 tileableProducerClone.generateResultTileValue(
1157 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1158 sliceOpToTile.getMixedSizes());
1159 if (failed(tileAndFuseResult)) {
1160 diag.attachNote(tileableProducer->getLoc())
1161 <<
"failed to tile producer op: " << *tileableProducer;
1166 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1167 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1168 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1169 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
1170 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1175 destinationTensors.front());
1178 return tileAndFuseResult->tiledOps;
1184 LDBG() <<
"Try to fuse an use by cloning";
1191 uses.push_back(&use);
1196 if (containingOp == use.getOwner()) {
1198 <<
"producer op use by containing op cannot be fused by cloning";
1206 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1215 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1216 "Parallel insert slice is not a valid clone destination");
1217 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1218 LDBG() <<
"resultNumber: " << resultNumber;
1222 fusedOp = rewriter.
clone(*producerOp);
1224 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1229bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1240 auto containingOps = state.
getPayloadOps(getContainingOp());
1241 if (!llvm::hasSingleElement(containingOps)) {
1243 <<
"requires exactly one containing_op handle (got "
1244 << llvm::range_size(containingOps) <<
")";
1246 Operation *containingOp = *containingOps.begin();
1249 if (std::empty(producerOps)) {
1251 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1258 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1259 for (
const auto &it :
enumerate(remainingProducers)) {
1262 int64_t numUsesInContainingOp =
1264 return containingOp->isAncestor(op);
1269 if (numUsesInContainingOp > 0) {
1270 if (numUsesInContainingOp == 1)
1271 remainingProducers.erase(remainingProducers.begin() + it.index());
1278 while (!remainingProducers.empty()) {
1279 auto nextProducer = getNextProducer();
1280 if (
failed(nextProducer)) {
1282 <<
"could not find next producer to fuse into container";
1283 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1291 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1298 auto [tiledOps, newContainingOp] =
1300 if (!tiledOps.empty()) {
1301 LDBG() <<
"\nFused a direct extract use\n" << *containingOp;
1302 fusedOps.append(tiledOps);
1303 if (newContainingOp) {
1311 LogicalResult replacementStatus =
1314 (
void)replacementStatus;
1315 assert(succeeded(replacementStatus) &&
1316 "unable to update transform state mapping");
1317 rewriter.
eraseOp(containingOp);
1318 containingOp = newContainingOp;
1325 rewriter,
diag, producerOp, containingOp);
1326 if (!tiledContainingOpOperand.empty()) {
1327 LDBG() <<
"\nFused an extract use through block argument\n"
1329 fusedOps.append(tiledContainingOpOperand);
1336 LDBG() <<
"\nFused an use by cloning\n" << *containingOp;
1337 fusedOps.push_back(cloned);
1343 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1344 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1348void transform::FuseIntoContainingOp::getEffects(
1366 if (isa<GenericOp>(
target)) {
1372 if (succeeded(generic)) {
1373 results.
push_back(generic->getOperation());
1376 return emitDefaultSilenceableFailure(
target);
1389 if (!isa<GenericOp>(
target)) {
1394 FailureOr<LinalgOp> named =
1396 if (succeeded(named)) {
1397 results.
push_back(named->getOperation());
1400 return emitDefaultSilenceableFailure(
target);
1414 if (interchangeVector.empty()) {
1419 unsigned numLoops = cast<LinalgOp>(
target.getOperation()).getNumLoops();
1420 if (interchangeVector.size() != numLoops) {
1421 return emitSilenceableError()
1422 << getIteratorInterchangeAttrName() <<
" has length ("
1423 << interchangeVector.size()
1424 <<
") different from the number of loops in the target operation ("
1435LogicalResult transform::InterchangeOp::verify() {
1437 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1438 if (!std::is_permutation(sequence.begin(), sequence.end(),
1439 permutation.begin(), permutation.end())) {
1441 <<
"expects iterator_interchange to be a permutation, found "
1442 << getIteratorInterchange();
1457 if (!isa<linalg::CopyOp>(targetOp)) {
1459 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1460 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1464 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1465 if (!copyOp.hasPureBufferSemantics()) {
1467 emitSilenceableError()
1468 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1469 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1475 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1476 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1477 Value input = inputs.front();
1478 Value output = outputs.front();
1483 if (!isa<ShapedType>(input.
getType())) {
1485 emitSilenceableError()
1486 <<
"cannot transform a linalg.copy which input has no shape";
1487 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1492 assert(isa<ShapedType>(output.
getType()));
1494 if (cast<ShapedType>(input.
getType()).getElementType() !=
1495 cast<ShapedType>(output.
getType()).getElementType()) {
1497 emitSilenceableError()
1498 <<
"cannot transform a linalg.copy with different source and "
1499 "destination element types ";
1500 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1521 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1522 FailureOr<LowerPackResult> res =
1526 <<
"cannot lower to pad + expand + transpose";
1529 transformResults.
push_back(res->expandShapeOp);
1530 transformResults.
push_back(res->transposeOp);
1543 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1544 FailureOr<LowerUnPackOpResult> res =
1548 emitSilenceableError()
1549 <<
"cannot lower to transpose + collapse + extract";
1550 diag.attachNote(
target->getLoc()) <<
"target payload op";
1553 transformResults.
push_back(res->emptyOp);
1554 transformResults.
push_back(res->transposeOp);
1555 transformResults.
push_back(res->collapseShapeOp);
1556 transformResults.
push_back(res->extractSliceOp);
1567 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1576 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1578 result.addTypes(resultTypes);
1586 if (getOps().has_value())
1587 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1590 if (!llvm::hasSingleElement(payloadOps)) {
1595 bool incorrectNumOperandTypes =
false;
1602 if (getInterface().has_value()) {
1603 auto iface = getInterface().value();
1604 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1607 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1608 !isa<TilingInterface>(op))
1610 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1611 !isa<LoopLikeOpInterface>(op))
1616 if (getOpAttrs().has_value()) {
1617 DictionaryAttr opAttrs = getOpAttrs().value();
1619 if (attr.getName() == getInterfaceAttrName() ||
1620 attr.getName() == getOpsAttrName())
1622 if (!op->
hasAttr(attr.getName()))
1624 if (op->
getAttr(attr.getName()) != attr.getValue())
1629 if (getFilterResultType().has_value()) {
1630 Type t = getFilterResultType().value();
1635 if (getFilterOperandTypes().has_value()) {
1636 mlir::ArrayAttr types = getFilterOperandTypes().value();
1639 if (types.size() == 1) {
1642 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1643 Type t = cast<::mlir::Type>(typeattr.getValue());
1645 [&](
Type operandType) { return operandType == t; }))
1650 if (types.size() != operandTypes.size()) {
1651 incorrectNumOperandTypes =
true;
1655 for (
auto [attr, operandType] :
1656 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1657 auto typeattr = cast<mlir::TypeAttr>(attr);
1658 Type type = cast<::mlir::Type>(typeattr.getValue());
1660 if (type != operandType)
1671 (*payloadOps.begin())->walk(matchFun);
1672 if (incorrectNumOperandTypes)
1674 "type, then it must contain as much types as "
1675 "the number of operands in the target ops");
1676 results.
set(cast<OpResult>(getResult()), res);
1691 Type &targetType,
Type &lowSizeType,
1693 Type &splitPointType) {
1694 FunctionType funcType;
1696 if (failed(parser.
parseType<FunctionType>(funcType)))
1699 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1700 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1701 "argument and one result";
1703 targetType = funcType.getInput(0);
1704 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1712 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1713 if (
target.hasDynamicShape()) {
1714 auto diag = emitSilenceableError()
1715 <<
"cannot compute parametric tile sizes for dynamically "
1716 "shaped payload op";
1717 diag.attachNote(
target->getLoc()) <<
"payload op";
1722 target, getDimension(), getTargetSize(), getDivisor());
1724 return emitSilenceableError()
1725 <<
"failed to compute multi-size tiling sizes";
1729 results.
assign(llvm::map_range(
1731 spec->lowTileSize * spec->lowTripCount}),
1732 [&builder,
this](
int64_t value) {
1744 builder,
target, getDimension(), targetSize, divisor);
1746 return emitSilenceableError() <<
"could not generate tile size computation";
1753 {spec->lowTileSize, spec->lowTripCount});
1754 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1755 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1756 assert(lowTileSize && highTileSize && splitPoint &&
1757 "tile sizes are not produced by operations");
1765void transform::MultiTileSizesOp::getEffects(
1769 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1775LogicalResult transform::MultiTileSizesOp::verify() {
1778 return emitOpError() <<
"expects all results type to be the same";
1797 Type linalgOpHType = transform::OperationType::get(
1798 builder.
getContext(), GenericOp::getOperationName());
1817 if (std::empty(targetOps)) {
1818 transformResults.
set(cast<OpResult>(getPackedOp()),
1823 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1824 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1825 return emitSilenceableError()
1826 <<
"requires target to map to exactly 1 LinalgOp (got "
1827 << llvm::range_size(targetOps) <<
")";
1830 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1831 return emitSilenceableError()
1832 <<
"requires number of packed sizes match the number of loops ("
1833 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1840 state, *
this, packedSizes, getMixedPackedSizes());
1843 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1847 transformResults.
set(cast<OpResult>(getPackedOp()),
1848 {maybeResult->packedLinalgOp.getOperation()});
1852void transform::PackOp::getEffects(
1864LogicalResult transform::PackGreedilyOp::verify() {
1866 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1867 <<
" is not a valid permutation";
1870 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1871 for (
auto [s, nmo] :
1872 llvm::zip_equal(getMixedMatmulPackedSizes(),
1873 getMatmulPaddedSizesNextMultipleOf())) {
1876 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1877 return emitOpError() <<
"at most one of the packed_size and the "
1878 "padded_sizes_next_multiple_of can be nonzero "
1879 "for the matmul strategy";
1892 auto linalgOp = dyn_cast<LinalgOp>(op);
1903 getMixedMatmulPackedSizes(),
1905 getMatmulPaddedSizesNextMultipleOf(),
1906 getMatmulInnerDimsOrder());
1907 if (succeeded(packResult)) {
1908 results.push_back(packResult->packedLinalgOp);
1911 results.push_back(linalgOp);
1913 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1919 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1923void transform::PackGreedilyOp::getEffects(
1935LogicalResult transform::PackTransposeOp::verify() {
1938 <<
" is not a valid permutation";
1942 <<
" is not a valid permutation";
1944 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1945 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1946 <<
" or " << getOuterPermAttrName()
1947 <<
" must be specified";
/// Selects which of the two permutations carried by a linalg.pack /
/// linalg.unpack op is being validated: the permutation of the outer
/// dimensions, or the permutation of the inner (tiled) dimensions.
enum class OuterOrInnerPerm {
  Outer = 0,
  Inner = 1
};
1963template <
typename RelayoutOpTy>
1964static bool isValidPackingPermutation(
1966 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1968 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1969 "applies to only pack or unpack operations");
1970 if (!op || permutation.empty())
1972 size_t innerRank = op.getInnerDimsPos().size();
1973 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1977 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1978 return permutation.size() == op.getSourceRank() &&
1981 return permutation.size() == op.getDestRank() &&
1989 auto packOrUnpackOps = state.
getPayloadOps(getTargetPackOrUnPackOp());
1992 if (std::empty(packOrUnpackOps)) {
1993 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1994 transformResults.
set(cast<OpResult>(getPackOp()), {});
1995 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
2001 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2002 !llvm::hasSingleElement(linalgOps)) {
2003 return emitSilenceableError()
2004 <<
"requires target to map to exactly 1 "
2005 "packing op and 1 packed op ("
2006 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
2007 << llvm::range_size(linalgOps) <<
")";
2011 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2012 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2013 if ((!packOp && !unPackOp)) {
2014 return emitSilenceableError() <<
"requires target to map to a "
2015 "linalg.pack or linalg.unpack";
2017 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2018 if (!linalgOpTarget)
2019 return emitSilenceableError() <<
"requires a LinalgOp target";
2023 if (packOp && packOp.getResult().hasOneUse())
2024 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2026 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2027 if (linalgOp != linalgOpTarget) {
2029 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
2030 : StringLiteral{
"not produced by the LinalgOp target"};
2031 return emitSilenceableError() << errorMsg;
2037 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
2038 OpOperand *packUse = linalgOp.getDpsInitOperand(
2039 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2041 if (!packOp || !packOp.getResult().hasOneUse())
2042 return emitSilenceableError() <<
"could not find matching pack op";
2046 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2048 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2049 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2050 ? StringLiteral{
"invalid outer_perm"}
2051 : StringLiteral{
"invalid inner_perm"};
2052 if (!isValidPackingPermutation(packOp, perm, permType) ||
2053 !isValidPackingPermutation(unPackOp, perm, permType)) {
2055 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2056 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
2062 assert(packOp && linalgOp &&
"unexpected null op");
2066 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2068 assert(succeeded(res) &&
"unexpected packTranspose failure");
2071 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2072 transformResults.
set(cast<OpResult>(getPackedOp()),
2073 {res->transposedLinalgOp});
2075 transformResults.
set(cast<OpResult>(getUnPackOp()),
2076 {res->transposedUnPackOp});
2078 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
2093 StringRef copyBackOp,
2094 bool usePrescribedTensorShapes) {
2095 auto resultType = transform::AnyOpType::get(
b.getContext());
2101 b.getI64ArrayAttr(paddingDimensions),
2104 (padToMultipleOf.empty()
2106 :
b.getDenseI64ArrayAttr(padToMultipleOf)),
2107 b.getI64ArrayAttr(nofoldFlags),
2108 b.getArrayAttr(transposePaddings),
2109 b.getStringAttr(copyBackOp),
2111 usePrescribedTensorShapes ?
b.getUnitAttr() :
nullptr);
2119 StringRef copyBackOp,
2120 bool usePrescribedTensorShapes) {
2121 auto resultType = transform::AnyOpType::get(
b.getContext());
2125 staticPadToMultipleOf);
2131 b.getI64ArrayAttr(paddingDimensions),
2132 dynamicPadToMultipleOf,
2133 staticPadToMultipleOf,
2134 b.getI64ArrayAttr(nofoldFlags),
2135 b.getArrayAttr(transposePaddings),
2137 usePrescribedTensorShapes);
2140void PadOp::getEffects(
2148SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2150 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(),
b);
2153DiagnosedSilenceableFailure
2154transform::PadOp::apply(transform::TransformRewriter &rewriter,
2155 transform::TransformResults &results,
2156 transform::TransformState &state) {
2157 auto transformOp = cast<TransformOpInterface>(getOperation());
2158 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
2161 auto linalgTarget = dyn_cast<LinalgOp>(
target);
2162 if (!linalgTarget) {
2163 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
2164 diag.attachNote(
target->getLoc()) <<
"target op";
2169 SmallVector<bool> nofoldFlags;
2170 for (int64_t packPadding :
2172 nofoldFlags.push_back(
static_cast<bool>(packPadding));
2175 SmallVector<Attribute> paddingValues;
2176 for (
auto const &[untypedAttr, elementOrTensorType] :
2177 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2179 if (isa<ub::PoisonAttr>(untypedAttr)) {
2180 paddingValues.push_back(untypedAttr);
2183 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2185 emitOpError(
"expects padding values to be typed attributes or poison");
2190 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2194 if (!parsedAttr || parsedAttr.getType() != elementType) {
2196 << elementType <<
", got " << untypedAttr;
2197 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2200 paddingValues.push_back(parsedAttr);
2204 if (attr.getType() != elementType) {
2206 << elementType <<
", got " << attr;
2207 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2210 paddingValues.push_back(attr);
2214 SmallVector<SmallVector<int64_t>> transposePaddings;
2215 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2217 cast<ArrayAttr>(transposeVector)));
2224 SmallVector<int64_t> padToMultipleOf;
2226 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2229 if (padToMultipleOf.empty())
2231 SmallVector<int64_t>(
options.paddingDimensions.size(), 1);
2233 options.padToMultipleOf = padToMultipleOf;
2234 options.paddingValues = paddingValues;
2235 options.nofoldFlags = nofoldFlags;
2236 if (getCopyBackOp() ==
2237 bufferization::MaterializeInDestinationOp::getOperationName()) {
2238 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2239 BufferizationMaterializeInDestination;
2240 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2241 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2242 }
else if (getCopyBackOp() == kCopyOpNone) {
2243 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2245 llvm_unreachable(
"unsupported copy_back op");
2248 bool irChanged =
false;
2249 if (getUsePrescribedTensorShapes() &&
2250 linalgTarget.hasPureTensorSemantics()) {
2251 OpBuilder::InsertionGuard g(rewriter);
2253 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2254 for (
auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2255 if (ShapedType::isStatic(dim))
2257 options.setSizeToPadTo(operand.getOperandNumber(), i,
2259 operand.get().getLoc(),
2266 SmallVector<Value> replacements;
2267 SmallVector<tensor::PadOp> newPadOps;
2269 replacements, newPadOps))) {
2275 auto diag = emitSilenceableError() <<
"failed to pad op";
2276 diag.attachNote(
target->getLoc()) <<
"target op";
2285 rewriter.
replaceOp(linalgTarget, replacements);
2286 paddedOps.push_back(paddedOp);
2287 padOps.append(newPadOps.begin(), newPadOps.end());
2288 if (
options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2289 for (Value v : replacements) {
2290 Operation *copyBackOp = v.getDefiningOp();
2291 if (!llvm::is_contained(copyBackOps, copyBackOp))
2292 copyBackOps.push_back(copyBackOp);
2297 results.
set(cast<OpResult>(getPadded()), paddedOps);
2298 results.
set(cast<OpResult>(getPad()), padOps);
2299 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2303LogicalResult transform::PadOp::verify() {
2304 SmallVector<int64_t> nofoldFlags =
2306 if (any_of(nofoldFlags, [](int64_t packPadding) {
2307 return packPadding != 0 && packPadding != 1;
2310 <<
"expects nofold_flags to contain booleans (0/1), found "
2311 << getNofoldFlags();
2314 SmallVector<int64_t> paddingDimensions =
2316 if (any_of(paddingDimensions,
2317 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2318 return emitOpError() <<
"expects padding_dimensions to contain positive "
2320 << getPaddingDimensions();
2322 if (!getMixedPadToMultipleOf().empty()) {
2323 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2324 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2327 ArrayAttr transposes = getTransposePaddings();
2328 for (Attribute attr : transposes) {
2330 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2331 if (!std::is_permutation(sequence.begin(), sequence.end(),
2332 transpose.begin(), transpose.end())) {
2334 <<
"expects transpose_paddings to be a permutation, found "
2338 if (getCopyBackOp() !=
2339 bufferization::MaterializeInDestinationOp::getOperationName() &&
2340 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2341 getCopyBackOp() != kCopyOpNone)
2350void transform::PadTilingInterfaceOp::build(OpBuilder &
b,
2353 ArrayRef<int64_t> paddingSizes,
2354 bool padToMultipleOf) {
2355 auto resultType = transform::AnyOpType::get(
b.getContext());
2364 :
b.getDenseI64ArrayAttr(paddingSizes)),
2366 padToMultipleOf ?
b.getUnitAttr() :
nullptr);
2369void transform::PadTilingInterfaceOp::build(
2371 ArrayRef<OpFoldResult> mixedPaddingSizes,
bool padToMultipleOf) {
2372 auto resultType = transform::AnyOpType::get(
b.getContext());
2373 SmallVector<int64_t> staticPaddingSizes;
2374 SmallVector<Value> dynamicPaddingSizes;
2376 staticPaddingSizes);
2382 dynamicPaddingSizes,
2387void transform::PadTilingInterfaceOp::getEffects(
2388 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2395SmallVector<OpFoldResult>
2396transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2401DiagnosedSilenceableFailure
2402transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2403 transform::TransformResults &results,
2404 transform::TransformState &state) {
2405 SmallVector<Operation *> paddedOps, padOps;
2408 auto targetOp = dyn_cast<TilingInterface>(
target);
2410 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2411 diag.attachNote(
target->getLoc()) <<
"target op";
2418 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2419 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2421 diag.attachNote(
target->getLoc()) <<
"target op";
2426 SmallVector<Attribute> paddingValues;
2427 for (
auto const &[untypedAttr, elementOrTensorType] :
2428 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2429 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2432 if (isa<ub::PoisonAttr>(untypedAttr)) {
2433 paddingValues.push_back(untypedAttr);
2437 emitOpError(
"expects padding values to be typed attributes or poison");
2441 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2445 if (!parsedAttr || parsedAttr.getType() != elementType) {
2447 << elementType <<
", got " << attr;
2448 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2451 paddingValues.push_back(parsedAttr);
2455 if (attr.getType() != elementType) {
2457 << elementType <<
", got " << attr;
2458 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2461 paddingValues.push_back(attr);
2465 PadTilingInterfaceOptions
options;
2466 options.setPaddingValues(paddingValues)
2467 .setPaddingSizes(getMixedPaddingSizes())
2468 .setPadToMultipleOf(getPadToMultipleOf());
2470 OpBuilder::InsertionGuard g(rewriter);
2473 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options);
2474 if (
failed(maybePadOps)) {
2475 auto diag = emitSilenceableError() <<
"failed to pad op";
2476 diag.attachNote(
target->getLoc()) <<
"target op";
2479 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2482 paddedOps.push_back(paddedOp);
2483 padOps.append(paddedOperands.begin(), paddedOperands.end());
2484 rewriter.
replaceOp(targetOp.getOperation(), slicedResults);
2487 results.
set(cast<OpResult>(getPadded()), paddedOps);
2488 results.
set(cast<OpResult>(getPad()), padOps);
2492LogicalResult transform::PadTilingInterfaceOp::verify() {
return success(); }
2498DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2499 transform::TransformRewriter &rewriter,
2500 transform::TransformResults &transformResults,
2501 transform::TransformState &state) {
2504 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2506 <<
"requires exactly one target and one loop handle (got "
2507 << llvm::range_size(targetOps) <<
" and "
2508 << llvm::range_size(loopOps) <<
")";
2511 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2512 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2513 if (!padOp || !loopOp)
2516 FailureOr<linalg::detail::PackingResult>
result =
2522 if (
result->clonedLoopIvs.empty()) {
2523 transformResults.
set(cast<OpResult>(getPackingLoop()),
2524 {
result->hoistedPadOp.getOperation()});
2527 auto outerPackedLoop =
2529 transformResults.
set(cast<OpResult>(getPackingLoop()),
2530 {outerPackedLoop.getOperation()});
2534LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2535 ArrayRef<int64_t> transpose = getTranspose();
2536 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2537 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2539 return emitOpError() <<
"expects transpose to be a permutation, found "
2545void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2546 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2553DiagnosedSilenceableFailure
2554transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2556 transform::ApplyToEachResultList &results,
2557 transform::TransformState &state) {
2558 tensor::PadOp hoistedPadOp;
2559 SmallVector<TransposeOp> transposeOps;
2560 FailureOr<Value>
result =
2562 hoistedPadOp, transposeOps);
2573 return emitDefaultSilenceableFailure(
target);
2576LogicalResult transform::HoistPadOp::verify() {
2577 ArrayRef<int64_t> transpose = getTranspose();
2578 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2579 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2581 return emitOpError() <<
"expects transpose to be a permutation, found "
2591DiagnosedSilenceableFailure
2592transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2594 transform::ApplyToEachResultList &results,
2595 transform::TransformState &state) {
2596 LinalgPromotionOptions promotionOptions;
2597 if (!getOperandsToPromote().empty())
2600 if (getUseFullTilesByDefault())
2602 getUseFullTilesByDefault());
2603 if (getUseOriginalSubviewSize())
2607 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2608 if (!getUseFullTileBuffers().empty())
2610 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2611 if (getAlignment().has_value())
2612 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2613 if (getMemorySpace().has_value())
2614 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2616 if (getMapping().has_value()) {
2618 auto mapping = *getMapping();
2619 if (mapping.size() > 1)
2620 return emitDefaultDefiniteFailure(
target);
2622 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2624 if (addressSpace.getAddressSpace() ==
2625 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2632 }
else if (addressSpace.getAddressSpace() ==
2633 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2641 return emitDefaultDefiniteFailure(
target);
2646 return emitDefaultDefiniteFailure(
target);
2651 return emitDefaultDefiniteFailure(
target);
2660DiagnosedSilenceableFailure
2661transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2662 TransformResults &transformResults,
2663 TransformState &state) {
2667 for (Operation *
target : payload) {
2668 if (
target->getNumOperands() > 0)
2670 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2671 target->getNumRegions() > 0)
2673 <<
"expected target that is isolated from above";
2677 Operation *pattern = &getBodyRegion().front().front();
2678 SmallVector<Operation *> replacements;
2679 for (Operation *
target : payload) {
2680 if (getOperation()->isAncestor(
target))
2687 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2691void transform::ReplaceOp::getEffects(
2692 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2698LogicalResult transform::ReplaceOp::verify() {
2699 if (!getBodyRegion().hasOneBlock())
2701 if (std::distance(getBodyRegion().front().begin(),
2702 getBodyRegion().front().end()) != 1)
2703 return emitOpError() <<
"expected one operation in block";
2704 Operation *
replacement = &getBodyRegion().front().front();
2707 <<
"expected replacement without operands";
2708 if (!
replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2711 <<
"expect op that is isolated from above";
2719DiagnosedSilenceableFailure
2720transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2722 transform::ApplyToEachResultList &results,
2723 transform::TransformState &state) {
2724 scf::SCFTilingOptions tilingOptions;
2725 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &
b, Operation *) {
2726 SmallVector<OpFoldResult> tileSizes;
2727 Location loc =
target.getLoc();
2728 SmallVector<OpFoldResult> allShapeSizes =
2729 target.createFlatListOfOperandDims(
b, loc);
2730 AffineMap map =
target.getShapesToLoopsMap();
2733 SmallVector<OpFoldResult> shapeSizes =
2738 for (OpFoldResult shapeSize : shapeSizes) {
2740 :
b.getIndexAttr(1));
2745 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2746 rewriter, cast<TilingInterface>(
target.getOperation()), tilingOptions);
2747 if (
failed(maybeTilingResult))
2748 return emitDefaultDefiniteFailure(
target);
2750 if (
target->getNumResults())
2755 results.
reserve(maybeTilingResult->tiledOps.size());
2756 for (Operation *tiled : maybeTilingResult->tiledOps)
2765DiagnosedSilenceableFailure
2766transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2767 transform::TransformResults &results,
2768 transform::TransformState &state) {
2769 SmallVector<Operation *> loops;
2771 auto tilingOp = dyn_cast<TilingInterface>(*
target);
2773 DiagnosedSilenceableFailure
diag =
2774 emitSilenceableError()
2775 <<
"expected the payload to implement TilingInterface";
2776 diag.attachNote(
target->getLoc()) <<
"payload op";
2780 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2781 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2782 if (
failed(generatedLoops))
2783 return emitDefaultDefiniteFailure(
target);
2784 for (scf::ForOp &loop : *generatedLoops) {
2785 loops.push_back(loop.getOperation());
2789 results.
set(cast<OpResult>(getResult()), loops);
2797DiagnosedSilenceableFailure
2798transform::RewriteInDestinationPassingStyleOp::applyToOne(
2799 transform::TransformRewriter &rewriter, Operation *
target,
2800 transform::ApplyToEachResultList &results,
2801 transform::TransformState &state) {
2803 FailureOr<Operation *> maybeResult =
2805 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2806 [&rewriter](
auto op) {
2810 return emitDefaultSilenceableFailure(
target);
2819DiagnosedSilenceableFailure
2820SplitOp::apply(transform::TransformRewriter &rewriter,
2821 TransformResults &results, TransformState &state) {
2823 SmallVector<Operation *> payload =
2826 bool isMultiwaySplit = getMultiway();
2828 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2830 <<
"requires exactly one target when "
2831 "multiway split is enabled (got "
2832 << llvm::range_size(payload) <<
")";
2835 SmallVector<OpFoldResult> chunkSizes;
2837 if (!isMultiwaySplit)
2838 chunkSizes.reserve(payload.size());
2840 if (getDynamicChunkSizes()) {
2842 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2843 chunkSizes = llvm::to_vector(llvm::map_range(
2844 state.
getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2847 diag = emitSilenceableError()
2848 <<
"expected dynamic split point handle to point to a "
2849 "single-result index-typed op";
2850 diag.attachNote(op->
getLoc()) <<
"dynamic split point";
2855 chunkSizes = llvm::to_vector(
2856 llvm::map_range(state.
getParams(getDynamicChunkSizes()),
2857 [](Attribute attr) {
return OpFoldResult(attr); }));
2859 if (
diag.isSilenceableFailure())
2864 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2866 <<
"expected the dynamic split point handle to point to as "
2868 << chunkSizes.size() <<
") as the target handle ("
2869 << payload.size() <<
")";
2872 chunkSizes.resize(payload.size(),
2876 auto checkStructuredOpAndDimensions =
2877 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2879 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2880 diag.attachNote(loc) <<
"target op";
2884 if (getDimension() >= linalgOp.getNumLoops()) {
2885 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2886 <<
" does not exist in target op";
2887 diag.attachNote(loc) <<
"target op";
2893 auto checkFailureInSplitting =
2894 [&](
bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2903 SmallVector<Operation *> opList;
2904 if (isMultiwaySplit) {
2907 TilingInterface head, tail;
2908 Operation *
target = payload.front();
2910 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2913 DiagnosedSilenceableFailure
diag =
2914 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2915 if (
diag.isSilenceableFailure())
2918 for (
auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2921 target = tail.getOperation();
2926 linalgOp = cast<LinalgOp>(
target);
2927 Location loc =
target->getLoc();
2931 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2932 getDimension(), chunkSize);
2935 DiagnosedSilenceableFailure
diag =
2936 checkFailureInSplitting(!head && !tail, loc);
2937 if (
diag.isDefiniteFailure())
2940 opList.push_back(head.getOperation());
2945 opList.push_back(tail.getOperation());
2949 SmallVector<Operation *> first, second;
2950 Operation *noSecondPart =
nullptr;
2951 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2952 Operation *
target = std::get<0>(pair);
2953 Location loc =
target->getLoc();
2954 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2955 DiagnosedSilenceableFailure
diag =
2956 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2958 if (
diag.isSilenceableFailure())
2962 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2963 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2964 getDimension(), std::get<1>(pair));
2967 DiagnosedSilenceableFailure diagSplit =
2968 checkFailureInSplitting(!first.back() && !second.back(), loc);
2973 if (!second.back()) {
2979 if (second.size() != first.size() && !second.empty()) {
2980 auto diag = emitSilenceableError()
2981 <<
"splitting does not produce the second part for a subset "
2984 <<
"expected splitting to produce the second part of all "
2985 "or none of the targets";
2987 <<
"first target with no second part";
2991 opList.append(first);
2992 if (!second.empty())
2993 opList.append(second);
2995 results.
set(cast<OpResult>(getSplitList()), opList);
2999void SplitOp::getEffects(
3000 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3002 if (getDynamicChunkSizes())
3008ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &
result) {
3009 OpAsmParser::UnresolvedOperand
target, dynamicChunkSizes;
3010 IntegerAttr staticChunkSizes;
3014 OptionalParseResult dynamicPointParseResult =
3016 if (!dynamicPointParseResult.
has_value()) {
3017 int64_t staticChunkSizesValue;
3031 if (dynamicPointParseResult.
has_value()) {
3032 Type chunkSizesType;
3045 SplitOp::getStaticChunkSizesAttrName(
result.name).getValue(),
3047 result.addTypes(targetType);
3051void SplitOp::print(OpAsmPrinter &printer) {
3052 printer <<
" " << getTarget() <<
" after ";
3053 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
3054 if (staticChunkSize != ShapedType::kDynamic)
3055 printer << staticChunkSize;
3057 printer << getDynamicChunkSizes();
3060 {getStaticChunkSizesAttrName()});
3061 printer <<
" : " << getTarget().getType();
3062 if (staticChunkSize == ShapedType::kDynamic)
3063 printer <<
", " << getDynamicChunkSizes().getType();
3066LogicalResult SplitOp::verify() {
3067 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3068 (getDynamicChunkSizes() ==
nullptr)) {
3069 return emitOpError() <<
"expects either a dynamic or a static split "
3070 "point to be provided";
3079void transform::SplitReductionOp::build(
3080 OpBuilder &builder, OperationState &
result, Value
target,
3081 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
3082 bool useScalingAlgorithm,
bool useAlloc) {
3085 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(
result.name),
3088 SplitReductionOp::getInsertSplitDimensionAttrName(
result.name),
3090 if (innerParallel) {
3091 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(
result.name),
3094 if (useScalingAlgorithm) {
3096 SplitReductionOp::getUseScalingAlgorithmAttrName(
result.name),
3100 result.addAttribute(SplitReductionOp::getUseAllocAttrName(
result.name),
3103 auto resultType = transform::AnyOpType::get(ctx);
3104 result.addTypes({resultType, resultType, resultType, resultType});
3107DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3108 transform::TransformRewriter &rewriter, LinalgOp
target,
3109 transform::ApplyToEachResultList &results,
3110 transform::TransformState &state) {
3112 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3113 unsigned(getInsertSplitDimension()),
3114 bool(getInnerParallel())};
3117 FailureOr<SplitReductionResult> splitResult =
3118 (getUseScalingAlgorithm())
3122 return emitDefaultDefiniteFailure(
target);
3124 results.
push_back(splitResult->initOrAlloc);
3126 results.
push_back(splitResult->splitLinalgOp);
3127 results.
push_back(splitResult->resultCombiningLinalgOp);
3135void transform::TileReductionUsingForOp::build(
3136 OpBuilder &builder, OperationState &
result, Value
target,
3137 ArrayRef<int64_t> staticTileSizes) {
3144 auto opTy = transform::AnyOpType::get(ctx);
3150 staticTileSizesAttr);
3153DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3154 transform::TransformRewriter &rewriter, Operation *
target,
3155 transform::ApplyToEachResultList &results,
3156 transform::TransformState &state) {
3159 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3160 if (!partialReductionOp) {
3163 "Operation should implement PartialReductionOpInterface");
3166 SmallVector<unsigned> reductionDims =
3168 if (reductionDims.empty()) {
3169 for (
auto [idx, iteratorType] :
3170 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3171 if (iteratorType == utils::IteratorType::reduction)
3172 reductionDims.push_back(idx);
3176 scf::SCFTilingOptions
options;
3177 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3178 options.setReductionTilingStrategy(
3181 options.setReductionDims(reductionDims);
3182 FailureOr<scf::SCFTilingResult>
result =
3183 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3187 "failed to tile using partial reduction");
3190 for (Value initValue :
result->initialValues)
3192 for (
auto parallelTiledOp :
result->tiledOps)
3194 for (
auto mergeOp :
result->mergeOps)
3204void transform::TileReductionUsingForallOp::build(
3205 OpBuilder &builder, OperationState &
result, Value
target,
3206 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3214 auto opTy = transform::AnyOpType::get(ctx);
3221 staticNumThreadsAttr,
3222 staticTileSizesAttr,
3226DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3227 transform::TransformRewriter &rewriter, Operation *
target,
3228 transform::ApplyToEachResultList &results,
3229 transform::TransformState &state) {
3232 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3233 if (!partialReductionOp) {
3236 "Operation should implement PartialReductionOpInterface");
3238 SmallVector<OpFoldResult> numThreads =
3240 SmallVector<OpFoldResult> tileSizes =
3243 scf::SCFTilingOptions
options;
3244 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3245 options.setReductionTilingStrategy(
3247 if (!getNumThreads().empty()) {
3248 options.setNumThreads(numThreads);
3250 options.setTileSizes(tileSizes);
3252 if (
auto mapping = getMapping()) {
3253 options.setMapping(mapping.value().getValue());
3255 SmallVector<unsigned> reductionDims =
3257 if (reductionDims.empty()) {
3258 for (
auto [idx, iteratorType] :
3259 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3260 if (iteratorType == utils::IteratorType::reduction)
3261 reductionDims.push_back(idx);
3264 options.setReductionDims(reductionDims);
3265 FailureOr<scf::SCFTilingResult>
result =
3266 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3269 auto diag = emitSilenceableError() <<
"could not tile reduction";
3274 for (Value initValue :
result->initialValues)
3276 for (
auto parallelTiledOp :
result->tiledOps)
3278 for (
auto mergeOp :
result->mergeOps)
3288DiagnosedSilenceableFailure
3289transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3290 TransformResults &transformResults,
3291 TransformState &state) {
3293 SmallVector<Operation *> targetOps =
3296 if (!llvm::hasSingleElement(targetOps)) {
3298 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3302 Operation *
target = *targetOps.begin();
3303 auto linalgOp = dyn_cast<LinalgOp>(
target);
3304 auto tileableOp = dyn_cast<TilingInterface>(
target);
3309 OpBuilder builder(linalgOp.getContext());
3311 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3312 if (linalgOp.hasDynamicShape()) {
3313 auto diag = emitSilenceableError()
3314 <<
"cannot compute parametric tile sizes for dynamically "
3315 "shaped payload op";
3316 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3320 FailureOr<StaticContinuousTileSizeSpecification> spec =
3324 return emitSilenceableError()
3325 <<
"failed to compute multi-size tiling sizes";
3328 SmallVector<int64_t> chunkSizes;
3330 for (
auto &&[tileSize, tripCount] :
3331 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3332 chunkSizes.push_back(tileSize * tripCount);
3334 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3335 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3340 getI64AttrsFromI64(spec->tileSizes));
3341 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3342 getI64AttrsFromI64(chunkSizes));
3349 OpFoldResult targetSize = builder.
getIndexAttr(getTargetSize());
3350 unsigned dimension = getDimension();
3353 builder, tileableOp, dimension, targetSize,
true);
3355 return emitSilenceableError() <<
"could not generate tile size computation";
3360 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3365 SmallVector<Value> chunkSizes;
3367 for (
auto &&[tileSize, tripCount] :
3368 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3369 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3370 chunkSizes.push_back(splitPoint);
3373 auto getDefiningOps = [&](ArrayRef<Value> values) {
3374 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3380 getDefiningOps(spec->tileSizes));
3381 transformResults.
set(cast<OpResult>(getChunkSizes()),
3382 getDefiningOps(chunkSizes));
3387LogicalResult transform::ContinuousTileSizesOp::verify() {
3390 return emitOpError() <<
"expects all results type to be the same";
3396void transform::ContinuousTileSizesOp::getEffects(
3397 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3414 Type &tileSizesType,
3415 Type &chunkSizesType) {
3416 FunctionType funcType;
3418 if (failed(parser.
parseType<FunctionType>(funcType)))
3421 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3422 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3423 "argument and one result";
3425 targetType = funcType.getInput(0);
3426 tileSizesType = chunkSizesType = funcType.getResult(0);
3435void transform::TileUsingForOp::build(
3437 Value
target, ArrayRef<int64_t> staticTileSizes,
3438 ArrayRef<int64_t> interchange,
3439 std::optional<ArrayRef<bool>> scalableSizes) {
3440 return build(builder,
result, loopTypes,
3444 interchange, scalableSizes);
3447void transform::TileUsingForOp::build(
3448 OpBuilder &builder, OperationState &
result, Value
target,
3449 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3450 std::optional<ArrayRef<bool>> scalableSizes) {
3453 interchange, scalableSizes);
3456void transform::TileUsingForOp::build(
3457 OpBuilder &builder, OperationState &
result, Value
target,
3458 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3459 std::optional<ArrayRef<bool>> scalableSizes) {
3462 SmallVector<Type> loopTypes(1, builder.
getType<transform::AnyOpType>());
3463 build(builder,
result, loopTypes,
target, mixedTileSizes, interchange,
3467void transform::TileUsingForOp::build(
3469 Value
target, ArrayRef<OpFoldResult> mixedTileSizes,
3470 ArrayRef<int64_t> interchange,
3471 std::optional<ArrayRef<bool>> scalableSizes) {
3472 SmallVector<int64_t> staticTileSizes;
3473 SmallVector<Value> dynamicTileSizes;
3479 unsigned numExpectedLoops =
3480 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3481 SmallVector<Type> resultTypes;
3482 resultTypes.reserve(numExpectedLoops);
3483 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3484 "expected one loop type or as many as loops");
3485 if (loopTypes.size() == 1)
3486 resultTypes.append(numExpectedLoops, loopTypes[0]);
3488 llvm::append_range(resultTypes, loopTypes);
3489 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(),
false);
3490 if (scalableSizes.has_value())
3491 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3496 staticTileSizesAttr,
3498 expandedScalableSizes);
3501LogicalResult transform::TileUsingForOp::verify() {
3503 return emitOpError(
"expected same number of sizes (")
3505 << getScalableSizes().size() <<
")";
3506 ArrayRef<int64_t> staticSizes = getStaticSizes();
3507 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3508 if (getLoops().size() != numExpectedLoops)
3509 return emitOpError(
"expected number of loops to tile (")
3510 << numExpectedLoops <<
") to match number of `loops` results ("
3511 << getLoops().size() <<
")";
3515DiagnosedSilenceableFailure
3516transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3517 TransformResults &transformResults,
3518 TransformState &state) {
3519 ArrayRef<int64_t> tileSizes = getStaticSizes();
3521 SmallVector<Operation *> targets =
3523 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3524 SmallVector<SmallVector<int64_t>> paramSizes;
3528 if (isa<ParamType>(transformValue.getType())) {
3529 dynamicSizeProducers.push_back({});
3530 ArrayRef<Attribute> params = state.
getParams(transformValue);
3531 paramSizes.push_back(
3532 llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3533 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3536 if (paramSizes.back().size() != targets.size()) {
3537 DiagnosedSilenceableFailure
diag =
3538 emitSilenceableError()
3539 <<
"expected as many parameter values ("
3540 << dynamicSizeProducers.back().size() <<
") as target ops ("
3541 << targets.size() <<
")";
3542 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3548 paramSizes.push_back({});
3549 dynamicSizeProducers.push_back(
3552 if (dynamicSizeProducers.back().size() != targets.size()) {
3553 DiagnosedSilenceableFailure
diag =
3554 emitSilenceableError()
3555 <<
"expected as many dynamic size-producing operations ("
3556 << dynamicSizeProducers.back().size() <<
") as target ops ("
3557 << targets.size() <<
")";
3558 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3562 for (Operation *op : dynamicSizeProducers.back()) {
3568 DiagnosedSilenceableFailure
diag =
3569 emitSilenceableError() <<
"expected sizes to be produced by ops "
3570 "with a single index-type result";
3571 diag.attachNote(op->
getLoc()) <<
"size producer op";
3572 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3577 SmallVector<Operation *> tiled;
3578 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3579 loops.resize(getLoops().size());
3580 auto scalableSizes = getScalableSizes();
3581 for (
auto [i, op] : llvm::enumerate(targets)) {
3582 auto tilingInterface = dyn_cast<TilingInterface>(op);
3583 if (!tilingInterface) {
3584 DiagnosedSilenceableFailure
diag =
3585 emitSilenceableError()
3586 <<
"only ops implementing TilingInterface are supported";
3587 diag.attachNote(op->
getLoc()) <<
"target op";
3590 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3591 DiagnosedSilenceableFailure
diag =
3592 emitSilenceableError()
3593 <<
"too many tiles provided, expected at most "
3594 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3595 << tileSizes.size();
3596 diag.attachNote(op->
getLoc()) <<
"target op";
3600 scf::SCFTilingOptions tilingOptions;
3601 if (tileSizes.empty()) {
3602 tilingOptions.setTileSizeComputationFunction(
3603 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3607 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &
b,
3609 SmallVector<OpFoldResult> sizes;
3610 sizes.reserve(tileSizes.size());
3611 unsigned dynamicIdx = 0;
3613 for (
auto [ofrIdx, ofr] : llvm::enumerate(
getMixedSizes())) {
3614 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3615 if (scalableSizes[ofrIdx]) {
3617 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3619 vector::VectorScaleOp::create(
b, getLoc(),
b.getIndexType());
3621 arith::MulIOp::create(
b, getLoc(), val, vscale).getResult());
3623 sizes.push_back(attr);
3627 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3628 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3630 assert((dynamicSizes.empty() ^ params.empty()) &&
3631 "expected either dynamic sizes or parameters");
3632 if (!params.empty()) {
3633 sizes.push_back(
b.getIndexAttr(params[index]));
3635 sizes.push_back(dynamicSizes[index]->getResult(0));
3642 tilingOptions.setInterchange(getInterchange());
3643 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3644 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3645 if (
failed(maybeTilingResult))
3648 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3650 tiled.append(maybeTilingResult->tiledOps);
3651 for (
const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3652 loops[en2.index()].push_back(en2.value());
3655 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3656 for (
const auto &en : llvm::enumerate(loops))
3657 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3662SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3664 ArrayRef<int64_t> tileSizes = getStaticSizes();
3665 SmallVector<OpFoldResult> results;
3666 results.reserve(tileSizes.size());
3667 unsigned dynamicPos = 0;
3669 for (int64_t size : tileSizes) {
3670 if (size == ShapedType::kDynamic) {
3671 results.push_back(dynamic[dynamicPos++]);
3679void transform::TileUsingForOp::getEffects(
3680 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3691void transform::TileUsingForallOp::build(OpBuilder &builder,
3693 ArrayRef<int64_t> staticTileSizes,
3694 transform::TileSizesSpec,
3696 return build(builder,
result,
3704void transform::TileUsingForallOp::build(OpBuilder &builder,
3706 ArrayRef<OpFoldResult> mixedTileSizes,
3707 transform::TileSizesSpec,
3709 SmallVector<int64_t> staticTileSizes;
3710 SmallVector<Value> dynamicTileSizes;
3716 auto operationType = transform::AnyOpType::get(ctx);
3719 TypeRange{operationType, operationType},
3726 staticTileSizesAttr,
3730void transform::TileUsingForallOp::build(OpBuilder &builder,
3732 ArrayRef<int64_t> staticNumThreads,
3733 transform::NumThreadsSpec,
3737 NumThreadsSpec(), mapping);
3740void transform::TileUsingForallOp::build(OpBuilder &builder,
3742 ArrayRef<OpFoldResult> mixedNumThreads,
3743 transform::NumThreadsSpec,
3745 SmallVector<int64_t> staticNumThreads;
3746 SmallVector<Value> dynamicNumThreads;
3753 auto operationType = transform::AnyOpType::get(ctx);
3756 TypeRange{operationType, operationType},
3762 staticNumThreadsAttr,
3769static SmallVector<OpFoldResult>
3775 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3777 for (
auto [lb,
ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3779 rewriter, loc, normalizedUbExpr, {lb,
ub, step});
3780 normalizedUbs.push_back(normalizedUb);
3782 return normalizedUbs;
3798 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3801 denormalizedIvs.push_back(
3804 return denormalizedIvs;
3815 scf::ForallOp loop) {
3832 auto normalizedForallOp = scf::ForallOp::create(
3833 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3834 loop.getOutputs(), loop.getMapping(),
3837 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3839 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3844 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3845 normalizedForallOp.getRegionIterArgs().end());
3846 Block *origLoopBlock = loop.getBody();
3847 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3849 rewriter.
replaceOp(loop, normalizedForallOp);
3850 return normalizedForallOp;
3858 scf::SCFTilingResult &tilingResult) {
3860 auto tileableOp = dyn_cast<TilingInterface>(
target);
3863 transformOp.emitSilenceableError()
3864 <<
"only TilingInterface ops are supported";
3865 diag.attachNote(
target->getLoc()) <<
"target op";
3869 scf::SCFTilingOptions
options;
3870 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3871 if (!mixedNumThreads.empty()) {
3872 options.setNumThreads(mixedNumThreads);
3874 options.setTileSizes(mixedTileSizes);
3877 options.setMapping(mapping.value().getValue());
3879 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3880 scf::tileUsingSCF(rewriter, tileableOp,
options);
3882 if (failed(maybeTilingResult))
3883 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3885 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3887 tilingResult = *maybeTilingResult;
3889 if (mixedNumThreads.empty()) {
3890 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3893 scf::ForallOp normalizedForallOp =
3895 tilingResult.loops.front() = normalizedForallOp;
3905 auto transformOp = cast<TransformOpInterface>(getOperation());
3914 getPackedNumThreads()
3916 state, transformOp, mixedNumThreads, getPackedNumThreads())
3918 state, transformOp, mixedNumThreads, getMixedNumThreads());
3922 status = getPackedTileSizes()
3924 state, transformOp, mixedTileSizes, getPackedTileSizes())
3926 state, transformOp, mixedTileSizes, getMixedTileSizes());
3931 scf::SCFTilingResult tilingResult;
3933 rewriter, state, transformOp,
target, mixedNumThreads, mixedTileSizes,
3934 getMapping(), tilingResult);
3935 if (!
diag.succeeded())
3937 tileOps.push_back(tilingResult.loops.front());
3938 tiledOps.append(tilingResult.tiledOps);
3941 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3942 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3947void transform::TileUsingForallOp::getEffects(
3948 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3958SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3963SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3968LogicalResult TileUsingForallOp::verify() {
3969 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3970 static_cast<int>(getPackedNumThreads() != Value());
3971 if (numThreadsSpec > 1)
3973 "num_threads and packed_num_threads are mutually exclusive");
3974 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3975 static_cast<int>(getPackedTileSizes() != Value());
3976 if (tileSizesSpec > 1)
3978 "tile_sizes and packed_tile_sizes are mutually exclusive");
3979 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3980 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3981 "must be specified");
3989void transform::VectorizeChildrenAndApplyPatternsOp::build(
3990 OpBuilder &builder, OperationState &
result, Value
target,
3991 bool foldTypeExtensionsIntoContract,
bool vectorizePadding,
3992 bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3994 if (foldTypeExtensionsIntoContract) {
3996 VectorizeChildrenAndApplyPatternsOp::
3997 getFoldTypeExtensionsIntoContractAttrName(
result.name),
4000 if (vectorizePadding) {
4002 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4006 if (vectorizeExtract) {
4008 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4012 if (flatten1DDepthwiseConv) {
4014 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4024struct VectorizationPattern :
public RewritePattern {
4025 explicit VectorizationPattern(MLIRContext *context,
4026 bool vectorizeExtract =
false,
4027 bool flattenConv =
false)
4028 : RewritePattern(MatchAnyOpTypeTag(), 1, context),
4029 vectorizeNDExtract(vectorizeExtract),
4030 flatten1DDepthwiseConv(flattenConv) {}
4031 LogicalResult matchAndRewrite(Operation *op,
4032 PatternRewriter &rewriter)
const override {
4035 "Unsupported Op, cannot vectorize");
4036 FailureOr<VectorizationResult> vectorResults =
4038 {}, vectorizeNDExtract,
4039 flatten1DDepthwiseConv);
4040 if (
failed(vectorResults))
4042 rewriter.
replaceOp(op, vectorResults->replacements);
4049 bool vectorizeNDExtract =
false;
4053 bool flatten1DDepthwiseConv =
false;
4057DiagnosedSilenceableFailure
4058transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4059 transform::TransformRewriter &rewriter, Operation *
target,
4060 transform::ApplyToEachResultList &results,
4061 transform::TransformState &state) {
4062 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4063 auto diag = this->
emitOpError(
"requires isolated-from-above targets");
4064 diag.attachNote(
target->getLoc()) <<
"non-isolated target";
4070 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4071 getFlatten_1dDepthwiseConv());
4073 if (!getDisableTransferPermutationMapLoweringPatterns())
4076 if (!getDisableMultiReductionToContractPatterns())
4081 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
4082 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4084 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
4085 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
4088 patterns.add<CopyVectorizationPattern>(ctx);
4090 if (getFoldTypeExtensionsIntoContract())
4093 if (getVectorizePadding()) {
4101 TrackingListener listener(state, *
this);
4104 GreedyRewriteConfig().setListener(&listener))))
4105 return emitDefaultDefiniteFailure(
target);
4115DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4116 transform::TransformRewriter &rewriter,
4117 mlir::transform::TransformResults &transformResults,
4118 mlir::transform::TransformState &state) {
4120 if (std::empty(targets))
4122 auto transformOp = cast<TransformOpInterface>(getOperation());
4123 SmallVector<int64_t> vectorSizes;
4125 state, transformOp, getMixedVectorSizes(), vectorSizes);
4130 for (Operation *
target : targets) {
4133 <<
"Unsupported Op, cannot vectorize";
4135 FailureOr<VectorizationResult> vectorResults =
4137 getVectorizeNdExtract().value_or(
false),
4139 getAssumeDynamicDimsMatchVecSizes().value_or(
false),
4140 getCreateNamedContraction().value_or(
false));
4141 if (
failed(vectorResults)) {
4143 <<
"Attempted to vectorize, but failed";
4151void transform::VectorizeOp::getEffects(
4152 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4158SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4163LogicalResult transform::VectorizeOp::verify() {
4164 if (getStaticVectorSizes().size() != getScalableSizes().size())
4165 return emitOpError(
"expected same number of vector sizes (")
4166 << getStaticVectorSizes().size() <<
") and scalable sizes ("
4167 << getScalableSizes().size() <<
")";
4175DiagnosedSilenceableFailure
4176transform::HoistRedundantVectorTransfersOp::applyToOne(
4177 transform::TransformRewriter &rewriter, func::FuncOp
target,
4178 transform::ApplyToEachResultList &results,
4179 transform::TransformState &state) {
4192DiagnosedSilenceableFailure
4193transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4194 transform::TransformRewriter &rewriter, mlir::Operation *
target,
4195 transform::ApplyToEachResultList &results,
4196 transform::TransformState &state) {
4207DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4208 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4209 transform::ApplyToEachResultList &results,
4210 transform::TransformState &state) {
4212 auto maybeTransformed =
4215 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4218 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4221 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4224 .Case([&](linalg::Conv2DNchwFchwOp op) {
4227 .Default([&](Operation *op) {
4230 if (
failed(maybeTransformed))
4231 return emitDefaultSilenceableFailure(
target);
4233 results.
push_back(maybeTransformed->first);
4235 results.
push_back(maybeTransformed->second);
4243DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4244 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4245 transform::ApplyToEachResultList &results,
4246 transform::TransformState &state) {
4250 <<
"only elementwise flattening is supported";
4253 if (
target.getNumLoops() <= 1) {
4260 std::iota(reassociation.begin(), reassociation.end(), 0);
4261 auto maybeFlattened =
4263 if (
failed(maybeFlattened))
4265 <<
"attempted to flatten, but failed";
4266 results.
push_back(maybeFlattened->collapsedOp);
4275DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4276 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4277 transform::ApplyToEachResultList &results,
4278 transform::TransformState &state) {
4280 auto maybeTransformed =
4282 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4285 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4288 .Default([&](Operation *op) {
4291 if (
failed(maybeTransformed))
4292 return emitDefaultSilenceableFailure(
target);
4302DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4303 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4304 transform::ApplyToEachResultList &results,
4305 transform::TransformState &state) {
4307 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4308 auto maybeTransformed =
4310 .Case([&](linalg::MatmulOp op) {
4313 .Case([&](linalg::BatchMatmulOp op) {
4316 .Default([&](Operation *op) {
return failure(); });
4317 if (
failed(maybeTransformed))
4327template <
typename OpTy>
4328static DiagnosedSilenceableFailure
4332 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4333 tensor::ParallelInsertSliceOp>() &&
4336 if (
auto copySource =
4337 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4345 if (isa<mlir::ParallelCombiningOpInterface>(
target.getOperation()))
4348 Value extracted = tensor::ExtractSliceOp::create(
4351 Value copied = linalg::CopyOp::create(rewriter,
target.getLoc(),
4352 target.getSource(), extracted)
4364DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4365 transform::TransformRewriter &rewriter, Operation *targetOp,
4366 transform::ApplyToEachResultList &results,
4367 transform::TransformState &state) {
4370 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4371 return doit(rewriter,
target, results, state);
4372 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4373 return doit(rewriter,
target, results, state);
4375 DiagnosedSilenceableFailure
diag =
4376 emitSilenceableError()
4377 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4378 diag.attachNote(targetOp->
getLoc()) <<
"target op";
4386DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4387 transform::TransformRewriter &rewriter, Operation *
target,
4388 transform::ApplyToEachResultList &results,
4389 transform::TransformState &state) {
4391 if (!isa<linalg::CopyOp, tensor::PadOp>(
target)) {
4392 DiagnosedSilenceableFailure
diag =
4393 emitSilenceableError()
4394 <<
"only linalg.copy and tensor.pad target ops are supported";
4395 diag.attachNote(
target->getLoc()) <<
"target op";
4398 assert(
target->getNumResults() == 1 &&
"expected single result");
4399 auto resultShapedType = cast<ShapedType>(
target->getResult(0).getType());
4400 if (!resultShapedType.hasStaticShape()) {
4401 DiagnosedSilenceableFailure
diag =
4402 emitSilenceableError()
4403 <<
"only statically sized ops of rank <= 3 are supported";
4404 diag.attachNote(
target->getLoc()) <<
"target op";
4409 int64_t desiredBitAlignment = getDesiredBitAlignment();
4410 int64_t eltBitwidth =
4411 resultShapedType.getElementType().getIntOrFloatBitWidth();
4412 if (desiredBitAlignment % eltBitwidth != 0) {
4413 desiredBitAlignment = eltBitwidth;
4416 gpu::CopyMappingInfo mapping(
4418 getTotalNumThreads(),
4419 desiredBitAlignment,
4420 resultShapedType.getShape(),
4423 resultShapedType.getElementType().getIntOrFloatBitWidth());
4424 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4425 DiagnosedSilenceableFailure
diag =
4426 emitSilenceableError()
4427 <<
"too few threads to map copy op to threads on the most minor "
4428 "dimension, given alignment and vector size constraints, try "
4429 "smaller tile size of mapping to more threads";
4430 diag.attachNote(
target->getLoc()) <<
"target op";
4436 scf::SCFTilingResult tilingResult;
4443 ArrayRef<OpFoldResult>{},
4444 b.getArrayAttr(mapping.threadMapping),
4446 if (!
diag.succeeded())
4449 results.
push_back(tilingResult.loops.front());
4450 for (
auto op : tilingResult.tiledOps)
4459DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4460 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4461 transform::ApplyToEachResultList &results,
4462 transform::TransformState &state) {
4464 FailureOr<Operation *> maybeTransformed = failure();
4466 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4471 .Default([&](Operation *op) {
return false; });
4474 return emitSilenceableError()
4475 <<
"this operation is not supported to convert to Winograd Conv2D";
4478 if (
failed(maybeTransformed)) {
4479 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4486DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4487 transform::TransformRewriter &rewriter, Operation *
target,
4488 transform::ApplyToEachResultList &results,
4489 transform::TransformState &state) {
4491 FailureOr<Operation *> maybeTransformed = failure();
4494 .Case([&](linalg::WinogradFilterTransformOp op) {
4498 .Case([&](linalg::WinogradInputTransformOp op) {
4502 .Case([&](linalg::WinogradOutputTransformOp op) {
4509 DiagnosedSilenceableFailure
diag =
4510 emitSilenceableError()
4511 <<
"this operation is not supported to decompose into other operations";
4512 diag.attachNote(
target->getLoc()) <<
"target op";
4516 if (
failed(maybeTransformed)) {
4517 DiagnosedSilenceableFailure
diag =
4518 emitSilenceableError() <<
"decompose Winograd operations failed";
4519 diag.attachNote(
target->getLoc()) <<
"target op";
4527#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4529#define GET_OP_CLASSES
4530#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...