42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/ScopeExit.h"
44#include "llvm/ADT/SmallPtrSet.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
// Tag used by the LDBG()/debug-log machinery (llvm/Support/DebugLog.h,
// included above) to scope this file's debug output under
// -debug-only=linalg-transforms.
55#define DEBUG_TYPE "linalg-transforms"
62template <
typename PatternTy,
typename... Args>
65 using OpTy =
typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
72 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
80 return cast<LinalgOp>(
result->getOperation());
90 if (
auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 <<
"requires exactly one parameter associated";
103 result.push_back(params[0]);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 <<
"handle must be mapped to exactly one payload op";
113 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
120 transformOp.emitSilenceableError()
121 <<
"payload op must have exactly 1 index result";
141 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
143 for (
auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 <<
"expected the parameter to be associated with an integer "
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 <<
"payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 <<
"has " << op->getNumResults() <<
" results";
162 result.push_back(op->getResult(0));
176 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
180 if (isa<TransformParamTypeInterface>(
181 cast<Value>(paramOrHandle).
getType())) {
183 if (params.size() != 1)
184 return transformOp.emitSilenceableError() <<
"expected a single param";
186 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
190 Value handle = cast<Value>(paramOrHandle);
191 if (!isa<TransformHandleTypeInterface>(handle.getType()))
192 return transformOp.emitSilenceableError() <<
"unexpected value handle";
194 if (!llvm::hasSingleElement(payload))
195 return transformOp.emitSilenceableError()
196 <<
"requires param or handle that is mapped to 1 payload op";
198 Operation *paramOrHandlePayloadOp = *payload.begin();
201 return transformOp.emitSilenceableError()
202 <<
"requires param or handle to be result of op with 1 index "
208 return transformOp.emitSilenceableError()
209 <<
"requires param or handle to be the result of a constant like "
212 reified.push_back(attr.getInt());
221void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
226void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
231void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
236void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
242void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
245 options.rankReductionStrategy =
250void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
255void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
260void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
265void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
270void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
275void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
284void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
288 Operation *producer = opOperand->get().getDefiningOp();
289 Operation *consumer = opOperand->getOwner();
304 SmallVector<Operation *> getNewOps()
const {
305 return SmallVector<Operation *>(newOps.begin(), newOps.end());
309 void notifyOperationInserted(Operation *op,
310 OpBuilder::InsertPoint previous)
override {
311 ForwardingListener::notifyOperationInserted(op, previous);
313 if (previous.
isSet())
317 assert(
inserted.second &&
"expected newly created op");
320 void notifyOperationErased(Operation *op)
override {
321 ForwardingListener::notifyOperationErased(op);
322 op->
walk([&](Operation *op) { newOps.erase(op); });
334 llvm::scope_exit resetListener(
335 [&]() { rewriter.
setListener(previousListener); });
336 NewOpsListener newOpsListener(previousListener);
340 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
341 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
342 MaterializeInDestination;
343 }
else if (getMemcpyOp() ==
"memref.copy") {
346 }
else if (getMemcpyOp() ==
"linalg.copy") {
350 llvm_unreachable(
"invalid memcpy op");
352 if (getAllocOp() ==
"memref.alloc") {
355 }
else if (getAllocOp() ==
"memref.alloca") {
359 llvm_unreachable(
"invalid alloc op");
361 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
362 options.emitDealloc = getEmitDealloc();
366 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
373 <<
"failed to bufferize operation";
374 diag.attachNote(op->
getLoc()) <<
"target payload op";
377 allocatedBuffers.push_back(buffer);
381 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
382 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
386void transform::BufferizeToAllocationOp::getEffects(
388 if (getBufferizeDestinationOnly()) {
399LogicalResult transform::BufferizeToAllocationOp::verify() {
400 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
401 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
403 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
416 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.
getOwner());
423 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
431 if (!isa<TensorType, FloatType, IntegerType>(value.
getType()))
433 return llvm::any_of(value.
getUses(),
443 auto type = dyn_cast<RankedTensorType>(
tensor.getType());
445 return emitSilenceableError() <<
"non-tensor type: " <<
tensor;
459 for (
auto [pos, dim] : llvm::enumerate(type.getShape())) {
460 if (!ShapedType::isDynamic(dim))
465 tensor::DimOp::create(rewriter,
tensor.getLoc(),
tensor, cst);
466 preservedOps.insert(dimOp);
467 dynamicDims.push_back(dimOp);
469 auto allocation = bufferization::AllocTensorOp::create(
470 rewriter,
tensor.getLoc(), type, dynamicDims);
472 if (getMemorySpaceAttr())
473 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
474 Value allocated = allocation;
478 if (needsMaterialization) {
479 auto copy = bufferization::MaterializeInDestinationOp::create(
481 preservedOps.insert(
copy);
482 promoted.push_back(
copy.getResult());
484 promoted.push_back(allocated);
488 results.
setValues(cast<OpResult>(getPromoted()), promoted);
492void transform::PromoteTensorOp::getEffects(
508#define DOWNSCALE(trans) \
510 FailureOr<LinalgOp> res = tryApply<trans>(target); \
511 if (succeeded(res)) { \
512 results.push_back(*res); \
513 return DiagnosedSilenceableFailure::success(); \
// Expands to a concrete DownscaleSizeOneWindowed2DConvolution<a, b>
// instantiation so it can be passed through DOWNSCALE (commas in the
// template argument list would otherwise split macro arguments).
517#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
// Convenience wrapper: DOWNSCALE(...) (defined above) runs
// tryApply<pattern>(target) and, on success, records the result and
// returns DiagnosedSilenceableFailure::success().
518#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
531#undef DOWNSCALE_NORMAL
534 return emitDefaultSilenceableFailure(
target);
548 auto decomposableOp = dyn_cast<AggregatedOpInterface>(
target);
549 if (!decomposableOp) {
551 "payload is not a decomposable op"));
552 return emitDefaultSilenceableFailure(
target);
555 FailureOr<SmallVector<Value>> maybeNewResults =
556 decomposableOp.decomposeOperation(rewriter);
557 if (
failed(maybeNewResults))
558 return emitDefaultSilenceableFailure(
target);
560 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
561 for (
Value val : *maybeNewResults) {
562 Operation *definition = val.getDefiningOp();
573void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
580transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
584 options.allowReturnAllocsFromLoops =
true;
590 <<
"failed to analyze op";
592 rewriter,
target, state)))
594 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
607 bool applyCleanup,
bool useForall) {
609 builder,
result, loopTypes,
615 applyCleanup, useForall);
621 bool applyCleanup,
bool useForall) {
629 applyCleanup, useForall);
636 bool applyCleanup,
bool useForall) {
640 build(builder,
result, loopTypes,
target, mixedTileSizes,
641 mixedTileInterchange, applyCleanup, useForall);
648 bool applyCleanup,
bool useForall) {
655 staticTileInterchange);
660 auto staticTileInterchangeAttr =
662 unsigned numExpectedLoops =
663 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
665 resultTypes.reserve(numExpectedLoops);
666 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
667 "expected one loop type or as many as loops");
668 if (loopTypes.size() == 1)
669 resultTypes.append(numExpectedLoops, loopTypes[0]);
671 llvm::append_range(resultTypes, loopTypes);
676 dynamicTileInterchange,
678 staticTileInterchangeAttr,
685template <
typename Range>
689 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
695 auto tilingInterfaceOp = dyn_cast<TilingInterface>(
target);
696 if (!tilingInterfaceOp)
697 return transformOp->
emitError(
"only TilingInterface ops are supported");
700 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
701 applyFn(tilingInterfaceOp);
702 if (failed(tiledResults))
707 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
708 for (
Operation *toReplace : opsToReplace) {
709 for (
OpResult res : toReplace->getResults())
710 if (
auto replacement = tiledResults->replacements.lookup(res))
712 if (toReplace->use_empty()) {
718 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
719 assert(tiledResults->loops.size() == numLoops &&
720 "Mismatched number of loops, tile and fuse transform should have "
722 for (
unsigned int i = 0; i < numLoops; ++i)
723 loopOps[i].
push_back(tiledResults->loops[i]);
726 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
727 for (
unsigned int i = 0; i < numLoops; ++i)
728 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
737 auto transformOp = cast<TransformOpInterface>(getOperation());
741 state, transformOp, getMixedTileSizes(), tileSizes);
746 state, transformOp, getMixedTileInterchange(), tileInterchange);
750 scf::SCFTilingOptions tilingOptions;
751 tilingOptions.interchangeVector = tileInterchange;
752 bool useForall = getUseForall();
753 tilingOptions.setLoopType(useForall
754 ? scf::SCFTilingOptions::LoopType::ForallOp
755 : scf::SCFTilingOptions::LoopType::ForOp);
758 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
759 scf::SCFTileAndFuseOptions tileAndFuseOptions;
760 tileAndFuseOptions.tilingOptions = tilingOptions;
762 if (getApplyCleanup()) {
765 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
768 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
772 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
774 rewriter, getOperation(), state.
getPayloadOps(getTarget()), numLoops,
776 [&](TilingInterface tilingInterfaceOp)
777 -> FailureOr<scf::SCFTileAndFuseResult> {
778 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
785LogicalResult transform::FuseOp::verify() {
786 auto iterspace_rank = getStaticTileSizes().size();
788 if (permutation.size() > iterspace_rank)
790 <<
"interchange length exceeds iteration space dimensions ("
791 << iterspace_rank <<
"), found " << getTileInterchange();
793 for (
int64_t v : permutation) {
794 if (!ShapedType::isDynamic(v)) {
795 if (v < 0 || v >=
static_cast<int64_t>(iterspace_rank))
796 return emitOpError() <<
"expects interchange values to be in range [0, "
797 << iterspace_rank <<
"), found: " << v;
799 return emitOpError() <<
"found duplicate interchange value: " << v;
805 size_t numExpectedLoops =
806 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
807 if (numExpectedLoops != getNumResults() - 1)
808 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
818 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
822void transform::FuseOp::getEffects(
835void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
838 Value containingOp) {
839 result.addOperands({producerOp, containingOp});
840 auto resultType = transform::AnyOpType::get(builder.
getContext());
841 result.addTypes({resultType, resultType});
857 (domInfo.
dominates(containingOp, user))) {
858 dominatedUsers.insert(user);
861 if (dominatedUsers.empty())
865 auto forallOp = cast<scf::ForallOp>(containingOp);
871 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
876 newOuts.push_back(outputs[resultNumber]);
879 auto newforallOp = scf::ForallOp::create(
880 rewriter, loc, forallOp.getMixedLowerBound(),
881 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
882 forallOp.getMapping());
884 newforallOp.getRegion().takeBody(forallOp.getRegion());
889 newforallOp.getBody()->addArgument(newOuts.back().getType(),
890 newOuts.back().getLoc());
891 auto bbArgs = newforallOp.getBody()->getArguments();
894 Operation *op = use.getOwner();
895 return newforallOp->isProperAncestor(op);
899 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
901 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; });
902 Operation *firstYieldOp = yieldingOps.front();
905 Value dst = newforallOp.getRegionIterArgs().back();
907 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->
getLoc(), src,
908 dst, offsets, sizes, strides);
910 for (
auto result : llvm::enumerate(forallOp.getResults())) {
912 newforallOp->getResult(
result.index()));
915 newforallOp->getResults().back(),
917 Operation *user = use.getOwner();
918 return dominatedUsers.contains(user);
932 destWorklist.push_back(dst);
934 while (!destWorklist.empty()) {
935 Value currentDst = destWorklist.pop_back_val();
939 if (src == currentDst)
944 auto bbArg = dyn_cast<BlockArgument>(currentDst);
948 Block *parentBlock = bbArg.getOwner();
949 assert(parentBlock &&
"unlinked block argument");
952 assert(parentOp &&
"expected block argument with parent operation");
955 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
959 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
961 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
962 Value loopBlockArgument =
964 destWorklist.push_back(loopBlockArgument);
977static std::tuple<SmallVector<Operation *>,
Operation *>
980 LDBG() <<
"Try to fuse a direct extract use";
981 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
982 if (!tileableProducer) {
984 <<
"producer is not a TileableInterface: " << *producerOp;
991 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
992 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
993 return sliceOp && containingOp->isProperAncestor(sliceOp);
997 if (it == tileableProducer->getUsers().end()) {
998 diag.attachNote(tileableProducer->getLoc())
999 <<
"could not find fusion opportunity for: " << *tileableProducer;
1002 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
1015 if (LoopLikeOpInterface containerLoop =
1016 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
1022 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(
clone);
1026 for (
OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1027 Value producerOperand =
1028 clone->getOperand(initOperandPtr.getOperandNumber());
1030 containerLoop.getRegionIterArgs()) {
1031 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1032 Value consumerOperand =
1036 initOperandPtr.set(containerIterArg);
1042 tileableProducer = dyn_cast<TilingInterface>(
clone);
1047 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1048 LDBG() <<
"resultNumber: " << resultNumber;
1053 FailureOr<TilingResult> tileAndFuseResult =
1054 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1057 if (failed(tileAndFuseResult)) {
1058 diag.attachNote(tileableProducer->getLoc())
1059 <<
"failed to tile producer op: " << *tileableProducer;
1064 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
1065 LDBG() <<
"tiledProducer: " << *tiledOp;
1070 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1071 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1072 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1073 if (failed(maybeRankReduced)) {
1075 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
1076 << tileAndFuseResult->tiledValues[0]
1077 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
1080 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1084 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
1085 resultNumber, offsets, sizes);
1088 if (isa<LoopLikeOpInterface>(containingOp))
1089 rewriter.
eraseOp(tileableProducer);
1091 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1104 LDBG() <<
"Try to fuse an extract use through block argument";
1106 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1107 if (!tileableProducer) {
1109 <<
"producer is not a TileableInterface: " << *producerOp;
1114 scf::ForallOp forallOp;
1115 auto itProducerUses =
1116 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
1117 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1121 if (!forallOp || forallOp != containingOp) {
1122 diag.attachNote(tileableProducer->getLoc())
1123 <<
"could not find a use by the containing op: " << *tileableProducer;
1138 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1139 return sliceOp && containingOp->isProperAncestor(sliceOp);
1143 if (itBBArgUsers == bbArg.
getUsers().end()) {
1145 <<
"could not find fusion opportunity for bbArg: " << bbArg;
1148 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1156 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
1157 LDBG() <<
"resultNumber: " << resultNumber;
1162 rewriter, tileableProducer->getLoc(), tileableProducer,
1163 destinationTensors))) {
1164 diag.attachNote(tileableProducer->getLoc())
1165 <<
"failed to get destination tensors for: " << *tileableProducer;
1170 bvm.
map(destinationTensors[resultNumber], bbArg);
1171 auto tileableProducerClone =
1172 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
1173 llvm::scope_exit scopeGuard(
1174 [&]() { rewriter.
eraseOp(tileableProducerClone); });
1177 FailureOr<TilingResult> tileAndFuseResult =
1178 tileableProducerClone.generateResultTileValue(
1179 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1180 sliceOpToTile.getMixedSizes());
1181 if (failed(tileAndFuseResult)) {
1182 diag.attachNote(tileableProducer->getLoc())
1183 <<
"failed to tile producer op: " << *tileableProducer;
1188 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1189 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1190 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1191 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
1192 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1197 destinationTensors.front());
1200 return tileAndFuseResult->tiledOps;
1206 LDBG() <<
"Try to fuse an use by cloning";
1213 uses.push_back(&use);
1218 if (containingOp == use.getOwner()) {
1220 <<
"producer op use by containing op cannot be fused by cloning";
1228 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1237 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1238 "Parallel insert slice is not a valid clone destination");
1239 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1240 LDBG() <<
"resultNumber: " << resultNumber;
1244 fusedOp = rewriter.
clone(*producerOp);
1246 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1251bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1262 auto containingOps = state.
getPayloadOps(getContainingOp());
1263 if (!llvm::hasSingleElement(containingOps)) {
1265 <<
"requires exactly one containing_op handle (got "
1266 << llvm::range_size(containingOps) <<
")";
1268 Operation *containingOp = *containingOps.begin();
1271 if (std::empty(producerOps)) {
1273 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1280 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1281 for (
const auto &it :
enumerate(remainingProducers)) {
1284 int64_t numUsesInContainingOp =
1286 return containingOp->isAncestor(op);
1291 if (numUsesInContainingOp > 0) {
1292 if (numUsesInContainingOp == 1)
1293 remainingProducers.erase(remainingProducers.begin() + it.index());
1300 while (!remainingProducers.empty()) {
1301 auto nextProducer = getNextProducer();
1302 if (
failed(nextProducer)) {
1304 <<
"could not find next producer to fuse into container";
1305 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1313 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1320 auto [tiledOps, newContainingOp] =
1322 if (!tiledOps.empty()) {
1323 LDBG() <<
"\nFused a direct extract use\n" << *containingOp;
1324 fusedOps.append(tiledOps);
1325 if (newContainingOp) {
1333 LogicalResult replacementStatus =
1336 (
void)replacementStatus;
1337 assert(succeeded(replacementStatus) &&
1338 "unable to update transform state mapping");
1339 rewriter.
eraseOp(containingOp);
1340 containingOp = newContainingOp;
1347 rewriter,
diag, producerOp, containingOp);
1348 if (!tiledContainingOpOperand.empty()) {
1349 LDBG() <<
"\nFused an extract use through block argument\n"
1351 fusedOps.append(tiledContainingOpOperand);
1358 LDBG() <<
"\nFused an use by cloning\n" << *containingOp;
1359 fusedOps.push_back(cloned);
1365 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1366 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1370void transform::FuseIntoContainingOp::getEffects(
1388 if (isa<GenericOp>(
target)) {
1394 if (succeeded(generic)) {
1395 results.
push_back(generic->getOperation());
1398 return emitDefaultSilenceableFailure(
target);
1411 if (!isa<GenericOp>(
target)) {
1416 FailureOr<LinalgOp> named =
1418 if (succeeded(named)) {
1419 results.
push_back(named->getOperation());
1422 return emitDefaultSilenceableFailure(
target);
1436 if (interchangeVector.empty()) {
1441 unsigned numLoops = cast<LinalgOp>(
target.getOperation()).getNumLoops();
1442 if (interchangeVector.size() != numLoops) {
1443 return emitSilenceableError()
1444 << getIteratorInterchangeAttrName() <<
" has length ("
1445 << interchangeVector.size()
1446 <<
") different from the number of loops in the target operation ("
1457LogicalResult transform::InterchangeOp::verify() {
1459 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1460 if (!std::is_permutation(sequence.begin(), sequence.end(),
1461 permutation.begin(), permutation.end())) {
1463 <<
"expects iterator_interchange to be a permutation, found "
1464 << getIteratorInterchange();
1479 if (!isa<linalg::CopyOp>(targetOp)) {
1481 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1482 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1486 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1487 if (!copyOp.hasPureBufferSemantics()) {
1489 emitSilenceableError()
1490 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1491 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1497 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1498 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1499 Value input = inputs.front();
1500 Value output = outputs.front();
1505 if (!isa<ShapedType>(input.
getType())) {
1507 emitSilenceableError()
1508 <<
"cannot transform a linalg.copy which input has no shape";
1509 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1514 assert(isa<ShapedType>(output.
getType()));
1516 if (cast<ShapedType>(input.
getType()).getElementType() !=
1517 cast<ShapedType>(output.
getType()).getElementType()) {
1519 emitSilenceableError()
1520 <<
"cannot transform a linalg.copy with different source and "
1521 "destination element types ";
1522 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1543 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1544 FailureOr<LowerPackResult> res =
1548 <<
"cannot lower to pad + expand + transpose";
1551 transformResults.
push_back(res->expandShapeOp);
1552 transformResults.
push_back(res->transposeOp);
1565 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1566 FailureOr<LowerUnPackOpResult> res =
1570 emitSilenceableError()
1571 <<
"cannot lower to transpose + collapse + extract";
1572 diag.attachNote(
target->getLoc()) <<
"target payload op";
1575 transformResults.
push_back(res->emptyOp);
1576 transformResults.
push_back(res->transposeOp);
1577 transformResults.
push_back(res->collapseShapeOp);
1578 transformResults.
push_back(res->extractSliceOp);
1579 transformResults.
push_back(res->copyOp);
1590 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1599 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1601 result.addTypes(resultTypes);
1609 if (getOps().has_value())
1610 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1613 if (!llvm::hasSingleElement(payloadOps)) {
1618 bool incorrectNumOperandTypes =
false;
1625 if (getInterface().has_value()) {
1626 auto iface = getInterface().value();
1627 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1630 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1631 !isa<TilingInterface>(op))
1633 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1634 !isa<LoopLikeOpInterface>(op))
1639 if (getOpAttrs().has_value()) {
1640 DictionaryAttr opAttrs = getOpAttrs().value();
1642 if (attr.getName() == getInterfaceAttrName() ||
1643 attr.getName() == getOpsAttrName())
1645 if (!op->
hasAttr(attr.getName()))
1647 if (op->
getAttr(attr.getName()) != attr.getValue())
1652 if (getFilterResultType().has_value()) {
1653 Type t = getFilterResultType().value();
1658 if (getFilterOperandTypes().has_value()) {
1659 mlir::ArrayAttr types = getFilterOperandTypes().value();
1662 if (types.size() == 1) {
1665 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1666 Type t = cast<::mlir::Type>(typeattr.getValue());
1668 [&](
Type operandType) { return operandType == t; }))
1673 if (types.size() != operandTypes.size()) {
1674 incorrectNumOperandTypes =
true;
1678 for (
auto [attr, operandType] :
1679 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1680 auto typeattr = cast<mlir::TypeAttr>(attr);
1681 Type type = cast<::mlir::Type>(typeattr.getValue());
1683 if (type != operandType)
1694 (*payloadOps.begin())->walk(matchFun);
1695 if (incorrectNumOperandTypes)
1697 "type, then it must contain as much types as "
1698 "the number of operands in the target ops");
1699 results.
set(cast<OpResult>(getResult()), res);
1714 Type &targetType,
Type &lowSizeType,
1716 Type &splitPointType) {
1717 FunctionType funcType;
1719 if (failed(parser.
parseType<FunctionType>(funcType)))
1722 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1723 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1724 "argument and one result";
1726 targetType = funcType.getInput(0);
1727 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1735 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1736 if (
target.hasDynamicShape()) {
1737 auto diag = emitSilenceableError()
1738 <<
"cannot compute parametric tile sizes for dynamically "
1739 "shaped payload op";
1740 diag.attachNote(
target->getLoc()) <<
"payload op";
1745 target, getDimension(), getTargetSize(), getDivisor());
1747 return emitSilenceableError()
1748 <<
"failed to compute multi-size tiling sizes";
1752 results.
assign(llvm::map_range(
1754 spec->lowTileSize * spec->lowTripCount}),
1755 [&builder,
this](
int64_t value) {
1767 builder,
target, getDimension(), targetSize, divisor);
1769 return emitSilenceableError() <<
"could not generate tile size computation";
1776 {spec->lowTileSize, spec->lowTripCount});
1777 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1778 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1779 assert(lowTileSize && highTileSize && splitPoint &&
1780 "tile sizes are not produced by operations");
1788void transform::MultiTileSizesOp::getEffects(
1792 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1798LogicalResult transform::MultiTileSizesOp::verify() {
1801 return emitOpError() <<
"expects all results type to be the same";
1820 Type linalgOpHType = transform::OperationType::get(
1821 builder.
getContext(), GenericOp::getOperationName());
1840 if (std::empty(targetOps)) {
1841 transformResults.
set(cast<OpResult>(getPackedOp()),
1846 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1847 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1848 return emitSilenceableError()
1849 <<
"requires target to map to exactly 1 LinalgOp (got "
1850 << llvm::range_size(targetOps) <<
")";
1853 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1854 return emitSilenceableError()
1855 <<
"requires number of packed sizes match the number of loops ("
1856 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1863 state, *
this, packedSizes, getMixedPackedSizes());
1866 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1870 transformResults.
set(cast<OpResult>(getPackedOp()),
1871 {maybeResult->packedLinalgOp.getOperation()});
1875void transform::PackOp::getEffects(
1887LogicalResult transform::PackGreedilyOp::verify() {
1889 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1890 <<
" is not a valid permutation";
1893 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1894 for (
auto [s, nmo] :
1895 llvm::zip_equal(getMixedMatmulPackedSizes(),
1896 getMatmulPaddedSizesNextMultipleOf())) {
1899 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1900 return emitOpError() <<
"at most one of the packed_size and the "
1901 "padded_sizes_next_multiple_of can be nonzero "
1902 "for the matmul strategy";
1915 auto linalgOp = dyn_cast<LinalgOp>(op);
1926 getMixedMatmulPackedSizes(),
1928 getMatmulPaddedSizesNextMultipleOf(),
1929 getMatmulInnerDimsOrder());
1930 if (succeeded(packResult)) {
1931 results.push_back(packResult->packedLinalgOp);
1934 results.push_back(linalgOp);
1936 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1942 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1946void transform::PackGreedilyOp::getEffects(
1958LogicalResult transform::PackTransposeOp::verify() {
1961 <<
" is not a valid permutation";
1965 <<
" is not a valid permutation";
1967 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1968 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1969 <<
" or " << getOuterPermAttrName()
1970 <<
" must be specified";
1976enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1986template <
typename RelayoutOpTy>
1987static bool isValidPackingPermutation(
1989 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1991 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1992 "applies to only pack or unpack operations");
1993 if (!op || permutation.empty())
1995 size_t innerRank = op.getInnerDimsPos().size();
1996 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
2000 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
2001 return permutation.size() == op.getSourceRank() &&
2004 return permutation.size() == op.getDestRank() &&
2012 auto packOrUnpackOps = state.
getPayloadOps(getTargetPackOrUnPackOp());
2015 if (std::empty(packOrUnpackOps)) {
2016 transformResults.
set(cast<OpResult>(getPackedOp()), {});
2017 transformResults.
set(cast<OpResult>(getPackOp()), {});
2018 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
2024 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2025 !llvm::hasSingleElement(linalgOps)) {
2026 return emitSilenceableError()
2027 <<
"requires target to map to exactly 1 "
2028 "packing op and 1 packed op ("
2029 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
2030 << llvm::range_size(linalgOps) <<
")";
2034 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2035 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2036 if ((!packOp && !unPackOp)) {
2037 return emitSilenceableError() <<
"requires target to map to a "
2038 "linalg.pack or linalg.unpack";
2040 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2041 if (!linalgOpTarget)
2042 return emitSilenceableError() <<
"requires a LinalgOp target";
2046 if (packOp && packOp.getResult().hasOneUse())
2047 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2049 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2050 if (linalgOp != linalgOpTarget) {
2052 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
2053 : StringLiteral{
"not produced by the LinalgOp target"};
2054 return emitSilenceableError() << errorMsg;
2060 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
2061 OpOperand *packUse = linalgOp.getDpsInitOperand(
2062 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2064 if (!packOp || !packOp.getResult().hasOneUse())
2065 return emitSilenceableError() <<
"could not find matching pack op";
2069 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2071 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2072 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2073 ? StringLiteral{
"invalid outer_perm"}
2074 : StringLiteral{
"invalid inner_perm"};
2075 if (!isValidPackingPermutation(packOp, perm, permType) ||
2076 !isValidPackingPermutation(unPackOp, perm, permType)) {
2078 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2079 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
2085 assert(packOp && linalgOp &&
"unexpected null op");
2089 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2091 assert(succeeded(res) &&
"unexpected packTranspose failure");
2094 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2095 transformResults.
set(cast<OpResult>(getPackedOp()),
2096 {res->transposedLinalgOp});
2098 transformResults.
set(cast<OpResult>(getUnPackOp()),
2099 {res->transposedUnPackOp});
2101 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
2116 StringRef copyBackOp,
2117 bool usePrescribedTensorShapes) {
2118 auto resultType = transform::AnyOpType::get(
b.getContext());
2124 b.getI64ArrayAttr(paddingDimensions),
2127 (padToMultipleOf.empty()
2129 :
b.getDenseI64ArrayAttr(padToMultipleOf)),
2130 b.getI64ArrayAttr(nofoldFlags),
2131 b.getArrayAttr(transposePaddings),
2132 b.getStringAttr(copyBackOp),
2134 usePrescribedTensorShapes ?
b.getUnitAttr() :
nullptr);
2142 StringRef copyBackOp,
2143 bool usePrescribedTensorShapes) {
2144 auto resultType = transform::AnyOpType::get(
b.getContext());
2148 staticPadToMultipleOf);
2154 b.getI64ArrayAttr(paddingDimensions),
2155 dynamicPadToMultipleOf,
2156 staticPadToMultipleOf,
2157 b.getI64ArrayAttr(nofoldFlags),
2158 b.getArrayAttr(transposePaddings),
2160 usePrescribedTensorShapes);
2163void PadOp::getEffects(
2171SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2173 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(),
b);
2176DiagnosedSilenceableFailure
2177transform::PadOp::apply(transform::TransformRewriter &rewriter,
2178 transform::TransformResults &results,
2179 transform::TransformState &state) {
2180 auto transformOp = cast<TransformOpInterface>(getOperation());
2181 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
2184 auto linalgTarget = dyn_cast<LinalgOp>(
target);
2185 if (!linalgTarget) {
2186 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
2187 diag.attachNote(
target->getLoc()) <<
"target op";
2192 SmallVector<bool> nofoldFlags;
2193 for (int64_t packPadding :
2195 nofoldFlags.push_back(
static_cast<bool>(packPadding));
2198 SmallVector<Attribute> paddingValues;
2199 for (
auto const &[untypedAttr, elementOrTensorType] :
2200 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2203 paddingValues.push_back(untypedAttr);
2206 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2208 emitOpError(
"expects padding values to be typed attributes or poison");
2213 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2217 if (!parsedAttr || parsedAttr.getType() != elementType) {
2219 << elementType <<
", got " << untypedAttr;
2220 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2223 paddingValues.push_back(parsedAttr);
2227 if (attr.getType() != elementType) {
2229 << elementType <<
", got " << attr;
2230 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2233 paddingValues.push_back(attr);
2237 SmallVector<SmallVector<int64_t>> transposePaddings;
2238 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2240 cast<ArrayAttr>(transposeVector)));
2247 SmallVector<int64_t> padToMultipleOf;
2249 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2252 if (padToMultipleOf.empty())
2254 SmallVector<int64_t>(
options.paddingDimensions.size(), 1);
2256 options.padToMultipleOf = padToMultipleOf;
2257 options.paddingValues = paddingValues;
2258 options.nofoldFlags = nofoldFlags;
2259 if (getCopyBackOp() ==
2260 bufferization::MaterializeInDestinationOp::getOperationName()) {
2261 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2262 BufferizationMaterializeInDestination;
2263 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2264 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2265 }
else if (getCopyBackOp() == kCopyOpNone) {
2266 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2268 llvm_unreachable(
"unsupported copy_back op");
2271 bool irChanged =
false;
2272 if (getUsePrescribedTensorShapes() &&
2273 linalgTarget.hasPureTensorSemantics()) {
2274 OpBuilder::InsertionGuard g(rewriter);
2276 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2277 for (
auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2278 if (ShapedType::isStatic(dim))
2280 options.setSizeToPadTo(operand.getOperandNumber(), i,
2282 operand.get().getLoc(),
2289 SmallVector<Value> replacements;
2290 SmallVector<tensor::PadOp> newPadOps;
2292 replacements, newPadOps))) {
2298 auto diag = emitSilenceableError() <<
"failed to pad op";
2299 diag.attachNote(
target->getLoc()) <<
"target op";
2308 rewriter.
replaceOp(linalgTarget, replacements);
2309 paddedOps.push_back(paddedOp);
2310 padOps.append(newPadOps.begin(), newPadOps.end());
2311 if (
options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2312 for (Value v : replacements) {
2313 Operation *copyBackOp = v.getDefiningOp();
2314 if (!llvm::is_contained(copyBackOps, copyBackOp))
2315 copyBackOps.push_back(copyBackOp);
2320 results.
set(cast<OpResult>(getPadded()), paddedOps);
2321 results.
set(cast<OpResult>(getPad()), padOps);
2322 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2326LogicalResult transform::PadOp::verify() {
2327 SmallVector<int64_t> nofoldFlags =
2329 if (any_of(nofoldFlags, [](int64_t packPadding) {
2330 return packPadding != 0 && packPadding != 1;
2333 <<
"expects nofold_flags to contain booleans (0/1), found "
2334 << getNofoldFlags();
2337 SmallVector<int64_t> paddingDimensions =
2339 if (any_of(paddingDimensions,
2340 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2341 return emitOpError() <<
"expects padding_dimensions to contain positive "
2343 << getPaddingDimensions();
2345 if (!getMixedPadToMultipleOf().empty()) {
2346 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2347 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2350 ArrayAttr transposes = getTransposePaddings();
2351 for (Attribute attr : transposes) {
2353 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2354 if (!std::is_permutation(sequence.begin(), sequence.end(),
2355 transpose.begin(), transpose.end())) {
2357 <<
"expects transpose_paddings to be a permutation, found "
2361 if (getCopyBackOp() !=
2362 bufferization::MaterializeInDestinationOp::getOperationName() &&
2363 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2364 getCopyBackOp() != kCopyOpNone)
2373void transform::PadTilingInterfaceOp::build(OpBuilder &
b,
2376 ArrayRef<int64_t> paddingSizes,
2377 bool padToMultipleOf) {
2378 auto resultType = transform::AnyOpType::get(
b.getContext());
2387 :
b.getDenseI64ArrayAttr(paddingSizes)),
2389 padToMultipleOf ?
b.getUnitAttr() :
nullptr);
2392void transform::PadTilingInterfaceOp::build(
2394 ArrayRef<OpFoldResult> mixedPaddingSizes,
bool padToMultipleOf) {
2395 auto resultType = transform::AnyOpType::get(
b.getContext());
2396 SmallVector<int64_t> staticPaddingSizes;
2397 SmallVector<Value> dynamicPaddingSizes;
2399 staticPaddingSizes);
2405 dynamicPaddingSizes,
2410void transform::PadTilingInterfaceOp::getEffects(
2411 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2418SmallVector<OpFoldResult>
2419transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2424DiagnosedSilenceableFailure
2425transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2426 transform::TransformResults &results,
2427 transform::TransformState &state) {
2428 SmallVector<Operation *> paddedOps, padOps;
2431 auto targetOp = dyn_cast<TilingInterface>(
target);
2433 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2434 diag.attachNote(
target->getLoc()) <<
"target op";
2441 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2442 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2444 diag.attachNote(
target->getLoc()) <<
"target op";
2449 SmallVector<Attribute> paddingValues;
2450 for (
auto const &[untypedAttr, elementOrTensorType] :
2451 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2452 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2456 paddingValues.push_back(untypedAttr);
2460 emitOpError(
"expects padding values to be typed attributes or poison");
2464 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2468 if (!parsedAttr || parsedAttr.getType() != elementType) {
2470 << elementType <<
", got " << attr;
2471 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2474 paddingValues.push_back(parsedAttr);
2478 if (attr.getType() != elementType) {
2480 << elementType <<
", got " << attr;
2481 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2484 paddingValues.push_back(attr);
2488 PadTilingInterfaceOptions
options;
2489 options.setPaddingValues(paddingValues)
2490 .setPaddingSizes(getMixedPaddingSizes())
2491 .setPadToMultipleOf(getPadToMultipleOf());
2493 OpBuilder::InsertionGuard g(rewriter);
2496 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options);
2497 if (
failed(maybePadOps)) {
2498 auto diag = emitSilenceableError() <<
"failed to pad op";
2499 diag.attachNote(
target->getLoc()) <<
"target op";
2502 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2505 paddedOps.push_back(paddedOp);
2506 padOps.append(paddedOperands.begin(), paddedOperands.end());
2507 rewriter.
replaceOp(targetOp.getOperation(), slicedResults);
2510 results.
set(cast<OpResult>(getPadded()), paddedOps);
2511 results.
set(cast<OpResult>(getPad()), padOps);
2515LogicalResult transform::PadTilingInterfaceOp::verify() {
return success(); }
2521DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2522 transform::TransformRewriter &rewriter,
2523 transform::TransformResults &transformResults,
2524 transform::TransformState &state) {
2527 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2529 <<
"requires exactly one target and one loop handle (got "
2530 << llvm::range_size(targetOps) <<
" and "
2531 << llvm::range_size(loopOps) <<
")";
2534 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2535 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2536 if (!padOp || !loopOp)
2539 FailureOr<linalg::detail::PackingResult>
result =
2545 if (
result->clonedLoopIvs.empty()) {
2546 transformResults.
set(cast<OpResult>(getPackingLoop()),
2547 {
result->hoistedPadOp.getOperation()});
2550 auto outerPackedLoop =
2552 transformResults.
set(cast<OpResult>(getPackingLoop()),
2553 {outerPackedLoop.getOperation()});
2557LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2558 ArrayRef<int64_t> transpose = getTranspose();
2559 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2560 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2562 return emitOpError() <<
"expects transpose to be a permutation, found "
2568void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2569 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2576DiagnosedSilenceableFailure
2577transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2579 transform::ApplyToEachResultList &results,
2580 transform::TransformState &state) {
2581 tensor::PadOp hoistedPadOp;
2582 SmallVector<TransposeOp> transposeOps;
2583 FailureOr<Value>
result =
2585 hoistedPadOp, transposeOps);
2596 return emitDefaultSilenceableFailure(
target);
2599LogicalResult transform::HoistPadOp::verify() {
2600 ArrayRef<int64_t> transpose = getTranspose();
2601 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2602 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2604 return emitOpError() <<
"expects transpose to be a permutation, found "
2614DiagnosedSilenceableFailure
2615transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2617 transform::ApplyToEachResultList &results,
2618 transform::TransformState &state) {
2619 LinalgPromotionOptions promotionOptions;
2620 if (!getOperandsToPromote().empty())
2623 if (getUseFullTilesByDefault())
2625 getUseFullTilesByDefault());
2626 if (getUseOriginalSubviewSize())
2630 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2631 if (!getUseFullTileBuffers().empty())
2633 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2634 if (getAlignment().has_value())
2635 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2636 if (getMemorySpace().has_value())
2637 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2639 if (getMapping().has_value()) {
2641 auto mapping = *getMapping();
2642 if (mapping.size() > 1)
2643 return emitDefaultDefiniteFailure(
target);
2645 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2647 if (addressSpace.getAddressSpace() ==
2648 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2655 }
else if (addressSpace.getAddressSpace() ==
2656 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2664 return emitDefaultDefiniteFailure(
target);
2669 return emitDefaultDefiniteFailure(
target);
2674 return emitDefaultDefiniteFailure(
target);
2683DiagnosedSilenceableFailure
2684transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2685 TransformResults &transformResults,
2686 TransformState &state) {
2690 for (Operation *
target : payload) {
2691 if (
target->getNumOperands() > 0)
2693 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2694 target->getNumRegions() > 0)
2696 <<
"expected target that is isolated from above";
2700 Operation *pattern = &getBodyRegion().front().front();
2701 SmallVector<Operation *> replacements;
2702 for (Operation *
target : payload) {
2703 if (getOperation()->isAncestor(
target))
2710 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2714void transform::ReplaceOp::getEffects(
2715 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2721LogicalResult transform::ReplaceOp::verify() {
2722 if (!getBodyRegion().hasOneBlock())
2724 if (std::distance(getBodyRegion().front().begin(),
2725 getBodyRegion().front().end()) != 1)
2726 return emitOpError() <<
"expected one operation in block";
2727 Operation *
replacement = &getBodyRegion().front().front();
2730 <<
"expected replacement without operands";
2731 if (!
replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2734 <<
"expect op that is isolated from above";
2742DiagnosedSilenceableFailure
2743transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2745 transform::ApplyToEachResultList &results,
2746 transform::TransformState &state) {
2747 scf::SCFTilingOptions tilingOptions;
2748 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &
b, Operation *) {
2749 SmallVector<OpFoldResult> tileSizes;
2750 Location loc =
target.getLoc();
2751 SmallVector<OpFoldResult> allShapeSizes =
2752 target.createFlatListOfOperandDims(
b, loc);
2753 AffineMap map =
target.getShapesToLoopsMap();
2756 SmallVector<OpFoldResult> shapeSizes =
2761 for (OpFoldResult shapeSize : shapeSizes) {
2763 :
b.getIndexAttr(1));
2768 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2769 rewriter, cast<TilingInterface>(
target.getOperation()), tilingOptions);
2770 if (
failed(maybeTilingResult))
2771 return emitDefaultDefiniteFailure(
target);
2773 if (
target->getNumResults())
2778 results.
reserve(maybeTilingResult->tiledOps.size());
2779 for (Operation *tiled : maybeTilingResult->tiledOps)
2788DiagnosedSilenceableFailure
2789transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2790 transform::TransformResults &results,
2791 transform::TransformState &state) {
2792 SmallVector<Operation *> loops;
2794 auto tilingOp = dyn_cast<TilingInterface>(*
target);
2796 DiagnosedSilenceableFailure
diag =
2797 emitSilenceableError()
2798 <<
"expected the payload to implement TilingInterface";
2799 diag.attachNote(
target->getLoc()) <<
"payload op";
2803 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2804 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2805 if (
failed(generatedLoops))
2806 return emitDefaultDefiniteFailure(
target);
2807 for (scf::ForOp &loop : *generatedLoops) {
2808 loops.push_back(loop.getOperation());
2812 results.
set(cast<OpResult>(getResult()), loops);
2820DiagnosedSilenceableFailure
2821transform::RewriteInDestinationPassingStyleOp::applyToOne(
2822 transform::TransformRewriter &rewriter, Operation *
target,
2823 transform::ApplyToEachResultList &results,
2824 transform::TransformState &state) {
2826 FailureOr<Operation *> maybeResult =
2828 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2829 [&rewriter](
auto op) {
2833 return emitDefaultSilenceableFailure(
target);
2842DiagnosedSilenceableFailure
2843SplitOp::apply(transform::TransformRewriter &rewriter,
2844 TransformResults &results, TransformState &state) {
2846 SmallVector<Operation *> payload =
2849 bool isMultiwaySplit = getMultiway();
2851 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2853 <<
"requires exactly one target when "
2854 "multiway split is enabled (got "
2855 << llvm::range_size(payload) <<
")";
2858 SmallVector<OpFoldResult> chunkSizes;
2860 if (!isMultiwaySplit)
2861 chunkSizes.reserve(payload.size());
2863 if (getDynamicChunkSizes()) {
2865 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2866 chunkSizes = llvm::map_to_vector(
2867 state.
getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2870 diag = emitSilenceableError()
2871 <<
"expected dynamic split point handle to point to a "
2872 "single-result index-typed op";
2873 diag.attachNote(op->
getLoc()) <<
"dynamic split point";
2878 chunkSizes = llvm::map_to_vector(
2879 state.
getParams(getDynamicChunkSizes()),
2880 [](Attribute attr) {
return OpFoldResult(attr); });
2882 if (
diag.isSilenceableFailure())
2887 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2889 <<
"expected the dynamic split point handle to point to as "
2891 << chunkSizes.size() <<
") as the target handle ("
2892 << payload.size() <<
")";
2895 chunkSizes.resize(payload.size(),
2899 auto checkStructuredOpAndDimensions =
2900 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2902 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2903 diag.attachNote(loc) <<
"target op";
2907 if (getDimension() >= linalgOp.getNumLoops()) {
2908 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2909 <<
" does not exist in target op";
2910 diag.attachNote(loc) <<
"target op";
2916 auto checkFailureInSplitting =
2917 [&](
bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2926 SmallVector<Operation *> opList;
2927 if (isMultiwaySplit) {
2930 TilingInterface head, tail;
2931 Operation *
target = payload.front();
2933 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2936 DiagnosedSilenceableFailure
diag =
2937 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2938 if (
diag.isSilenceableFailure())
2941 for (
auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2944 target = tail.getOperation();
2949 linalgOp = cast<LinalgOp>(
target);
2950 Location loc =
target->getLoc();
2954 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2955 getDimension(), chunkSize);
2958 DiagnosedSilenceableFailure
diag =
2959 checkFailureInSplitting(!head && !tail, loc);
2960 if (
diag.isDefiniteFailure())
2963 opList.push_back(head.getOperation());
2968 opList.push_back(tail.getOperation());
2972 SmallVector<Operation *> first, second;
2973 Operation *noSecondPart =
nullptr;
2974 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2975 Operation *
target = std::get<0>(pair);
2976 Location loc =
target->getLoc();
2977 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2978 DiagnosedSilenceableFailure
diag =
2979 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2981 if (
diag.isSilenceableFailure())
2985 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2986 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2987 getDimension(), std::get<1>(pair));
2990 DiagnosedSilenceableFailure diagSplit =
2991 checkFailureInSplitting(!first.back() && !second.back(), loc);
2996 if (!second.back()) {
3002 if (second.size() != first.size() && !second.empty()) {
3003 auto diag = emitSilenceableError()
3004 <<
"splitting does not produce the second part for a subset "
3007 <<
"expected splitting to produce the second part of all "
3008 "or none of the targets";
3010 <<
"first target with no second part";
3014 opList.append(first);
3015 if (!second.empty())
3016 opList.append(second);
3018 results.
set(cast<OpResult>(getSplitList()), opList);
3022void SplitOp::getEffects(
3023 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3025 if (getDynamicChunkSizes())
3031ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &
result) {
3032 OpAsmParser::UnresolvedOperand
target, dynamicChunkSizes;
3033 IntegerAttr staticChunkSizes;
3037 OptionalParseResult dynamicPointParseResult =
3039 if (!dynamicPointParseResult.
has_value()) {
3040 int64_t staticChunkSizesValue;
3054 if (dynamicPointParseResult.
has_value()) {
3055 Type chunkSizesType;
3068 SplitOp::getStaticChunkSizesAttrName(
result.name).getValue(),
3070 result.addTypes(targetType);
3074void SplitOp::print(OpAsmPrinter &printer) {
3075 printer <<
" " << getTarget() <<
" after ";
3076 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
3077 if (staticChunkSize != ShapedType::kDynamic)
3078 printer << staticChunkSize;
3080 printer << getDynamicChunkSizes();
3083 {getStaticChunkSizesAttrName()});
3084 printer <<
" : " << getTarget().getType();
3085 if (staticChunkSize == ShapedType::kDynamic)
3086 printer <<
", " << getDynamicChunkSizes().getType();
3089LogicalResult SplitOp::verify() {
3090 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3091 (getDynamicChunkSizes() ==
nullptr)) {
3092 return emitOpError() <<
"expects either a dynamic or a static split "
3093 "point to be provided";
3102void transform::SplitReductionOp::build(
3103 OpBuilder &builder, OperationState &
result, Value
target,
3104 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
3105 bool useScalingAlgorithm,
bool useAlloc) {
3108 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(
result.name),
3111 SplitReductionOp::getInsertSplitDimensionAttrName(
result.name),
3113 if (innerParallel) {
3114 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(
result.name),
3117 if (useScalingAlgorithm) {
3119 SplitReductionOp::getUseScalingAlgorithmAttrName(
result.name),
3123 result.addAttribute(SplitReductionOp::getUseAllocAttrName(
result.name),
3126 auto resultType = transform::AnyOpType::get(ctx);
3127 result.addTypes({resultType, resultType, resultType, resultType});
3130DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3131 transform::TransformRewriter &rewriter, LinalgOp
target,
3132 transform::ApplyToEachResultList &results,
3133 transform::TransformState &state) {
3135 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3136 unsigned(getInsertSplitDimension()),
3137 bool(getInnerParallel())};
3140 FailureOr<SplitReductionResult> splitResult =
3141 (getUseScalingAlgorithm())
3145 return emitDefaultDefiniteFailure(
target);
3147 results.
push_back(splitResult->initOrAlloc);
3149 results.
push_back(splitResult->splitLinalgOp);
3150 results.
push_back(splitResult->resultCombiningLinalgOp);
3158void transform::TileReductionUsingForOp::build(
3159 OpBuilder &builder, OperationState &
result, Value
target,
3160 ArrayRef<int64_t> staticTileSizes) {
3167 auto opTy = transform::AnyOpType::get(ctx);
3173 staticTileSizesAttr);
3176DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3177 transform::TransformRewriter &rewriter, Operation *
target,
3178 transform::ApplyToEachResultList &results,
3179 transform::TransformState &state) {
3182 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3183 if (!partialReductionOp) {
3186 "Operation should implement PartialReductionOpInterface");
3189 SmallVector<unsigned> reductionDims =
3191 if (reductionDims.empty()) {
3192 for (
auto [idx, iteratorType] :
3193 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3194 if (iteratorType == utils::IteratorType::reduction)
3195 reductionDims.push_back(idx);
3199 scf::SCFTilingOptions
options;
3200 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3201 options.setReductionTilingStrategy(
3204 options.setReductionDims(reductionDims);
3205 FailureOr<scf::SCFTilingResult>
result =
3206 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3210 "failed to tile using partial reduction");
3213 for (Value initValue :
result->initialValues)
3215 for (
auto *parallelTiledOp :
result->tiledOps)
3217 for (
auto *mergeOp :
result->mergeOps)
3227void transform::TileReductionUsingForallOp::build(
3228 OpBuilder &builder, OperationState &
result, Value
target,
3229 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3237 auto opTy = transform::AnyOpType::get(ctx);
3244 staticNumThreadsAttr,
3245 staticTileSizesAttr,
3249DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3250 transform::TransformRewriter &rewriter, Operation *
target,
3251 transform::ApplyToEachResultList &results,
3252 transform::TransformState &state) {
3255 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3256 if (!partialReductionOp) {
3259 "Operation should implement PartialReductionOpInterface");
3261 SmallVector<OpFoldResult> numThreads =
3263 SmallVector<OpFoldResult> tileSizes =
3266 scf::SCFTilingOptions
options;
3267 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3268 options.setReductionTilingStrategy(
3270 if (!getNumThreads().empty()) {
3271 options.setNumThreads(numThreads);
3273 options.setTileSizes(tileSizes);
3275 if (
auto mapping = getMapping()) {
3276 options.setMapping(mapping.value().getValue());
3278 SmallVector<unsigned> reductionDims =
3280 if (reductionDims.empty()) {
3281 for (
auto [idx, iteratorType] :
3282 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3283 if (iteratorType == utils::IteratorType::reduction)
3284 reductionDims.push_back(idx);
3287 options.setReductionDims(reductionDims);
3288 FailureOr<scf::SCFTilingResult>
result =
3289 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3292 auto diag = emitSilenceableError() <<
"could not tile reduction";
3297 for (Value initValue :
result->initialValues)
3299 for (
auto *parallelTiledOp :
result->tiledOps)
3301 for (
auto *mergeOp :
result->mergeOps)
3311DiagnosedSilenceableFailure
3312transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3313 TransformResults &transformResults,
3314 TransformState &state) {
3316 SmallVector<Operation *> targetOps =
3319 if (!llvm::hasSingleElement(targetOps)) {
3321 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3325 Operation *
target = *targetOps.begin();
3326 auto linalgOp = dyn_cast<LinalgOp>(
target);
3327 auto tileableOp = dyn_cast<TilingInterface>(
target);
3332 OpBuilder builder(linalgOp.getContext());
3334 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3335 if (linalgOp.hasDynamicShape()) {
3336 auto diag = emitSilenceableError()
3337 <<
"cannot compute parametric tile sizes for dynamically "
3338 "shaped payload op";
3339 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3343 FailureOr<StaticContinuousTileSizeSpecification> spec =
3347 return emitSilenceableError()
3348 <<
"failed to compute multi-size tiling sizes";
3351 SmallVector<int64_t> chunkSizes;
3353 for (
auto &&[tileSize, tripCount] :
3354 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3355 chunkSizes.push_back(tileSize * tripCount);
3357 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3358 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3363 getI64AttrsFromI64(spec->tileSizes));
3364 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3365 getI64AttrsFromI64(chunkSizes));
3372 OpFoldResult targetSize = builder.
getIndexAttr(getTargetSize());
3373 unsigned dimension = getDimension();
3376 builder, tileableOp, dimension, targetSize,
true);
3378 return emitSilenceableError() <<
"could not generate tile size computation";
3383 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3388 SmallVector<Value> chunkSizes;
3390 for (
auto &&[tileSize, tripCount] :
3391 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3392 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3393 chunkSizes.push_back(splitPoint);
3396 auto getDefiningOps = [&](ArrayRef<Value> values) {
3397 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3403 getDefiningOps(spec->tileSizes));
3404 transformResults.
set(cast<OpResult>(getChunkSizes()),
3405 getDefiningOps(chunkSizes));
3410LogicalResult transform::ContinuousTileSizesOp::verify() {
3413 return emitOpError() <<
"expects all results type to be the same";
3419void transform::ContinuousTileSizesOp::getEffects(
3420 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3437 Type &tileSizesType,
3438 Type &chunkSizesType) {
3439 FunctionType funcType;
3441 if (failed(parser.
parseType<FunctionType>(funcType)))
3444 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3445 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3446 "argument and one result";
3448 targetType = funcType.getInput(0);
3449 tileSizesType = chunkSizesType = funcType.getResult(0);
3458void transform::TileUsingForOp::build(
3460 Value
target, ArrayRef<int64_t> staticTileSizes,
3461 ArrayRef<int64_t> interchange,
3462 std::optional<ArrayRef<bool>> scalableSizes) {
3463 return build(builder,
result, loopTypes,
3467 interchange, scalableSizes);
3470void transform::TileUsingForOp::build(
3471 OpBuilder &builder, OperationState &
result, Value
target,
3472 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3473 std::optional<ArrayRef<bool>> scalableSizes) {
3476 interchange, scalableSizes);
3479void transform::TileUsingForOp::build(
3480 OpBuilder &builder, OperationState &
result, Value
target,
3481 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3482 std::optional<ArrayRef<bool>> scalableSizes) {
3485 SmallVector<Type> loopTypes(1, builder.
getType<transform::AnyOpType>());
3486 build(builder,
result, loopTypes,
target, mixedTileSizes, interchange,
3490void transform::TileUsingForOp::build(
3492 Value
target, ArrayRef<OpFoldResult> mixedTileSizes,
3493 ArrayRef<int64_t> interchange,
3494 std::optional<ArrayRef<bool>> scalableSizes) {
3495 SmallVector<int64_t> staticTileSizes;
3496 SmallVector<Value> dynamicTileSizes;
3502 unsigned numExpectedLoops =
3503 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3504 SmallVector<Type> resultTypes;
3505 resultTypes.reserve(numExpectedLoops);
3506 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3507 "expected one loop type or as many as loops");
3508 if (loopTypes.size() == 1)
3509 resultTypes.append(numExpectedLoops, loopTypes[0]);
3511 llvm::append_range(resultTypes, loopTypes);
3512 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(),
false);
3513 if (scalableSizes.has_value())
3514 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3519 staticTileSizesAttr,
3521 expandedScalableSizes);
3524LogicalResult transform::TileUsingForOp::verify() {
3526 return emitOpError(
"expected same number of sizes (")
3528 << getScalableSizes().size() <<
")";
3529 ArrayRef<int64_t> staticSizes = getStaticSizes();
3530 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3531 if (getLoops().size() != numExpectedLoops)
3532 return emitOpError(
"expected number of loops to tile (")
3533 << numExpectedLoops <<
") to match number of `loops` results ("
3534 << getLoops().size() <<
")";
3538DiagnosedSilenceableFailure
3539transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3540 TransformResults &transformResults,
3541 TransformState &state) {
3542 ArrayRef<int64_t> tileSizes = getStaticSizes();
3544 SmallVector<Operation *> targets =
3546 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3547 SmallVector<SmallVector<int64_t>> paramSizes;
3551 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
3552 dynamicSizeProducers.push_back({});
3553 ArrayRef<Attribute> params = state.
getParams(transformValue);
3554 paramSizes.push_back(llvm::map_to_vector(params, [](Attribute attr) {
3555 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3558 if (paramSizes.back().size() != targets.size()) {
3559 DiagnosedSilenceableFailure
diag =
3560 emitSilenceableError()
3561 <<
"expected as many parameter values ("
3562 << dynamicSizeProducers.back().size() <<
") as target ops ("
3563 << targets.size() <<
")";
3564 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3570 paramSizes.push_back({});
3571 dynamicSizeProducers.push_back(
3574 if (dynamicSizeProducers.back().size() != targets.size()) {
3575 DiagnosedSilenceableFailure
diag =
3576 emitSilenceableError()
3577 <<
"expected as many dynamic size-producing operations ("
3578 << dynamicSizeProducers.back().size() <<
") as target ops ("
3579 << targets.size() <<
")";
3580 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3584 for (Operation *op : dynamicSizeProducers.back()) {
3590 DiagnosedSilenceableFailure
diag =
3591 emitSilenceableError() <<
"expected sizes to be produced by ops "
3592 "with a single index-type result";
3593 diag.attachNote(op->
getLoc()) <<
"size producer op";
3594 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3599 SmallVector<Operation *> tiled;
3600 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3601 loops.resize(getLoops().size());
3602 auto scalableSizes = getScalableSizes();
3603 for (
auto [i, op] : llvm::enumerate(targets)) {
3604 auto tilingInterface = dyn_cast<TilingInterface>(op);
3605 if (!tilingInterface) {
3606 DiagnosedSilenceableFailure
diag =
3607 emitSilenceableError()
3608 <<
"only ops implementing TilingInterface are supported";
3609 diag.attachNote(op->
getLoc()) <<
"target op";
3612 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3613 DiagnosedSilenceableFailure
diag =
3614 emitSilenceableError()
3615 <<
"too many tiles provided, expected at most "
3616 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3617 << tileSizes.size();
3618 diag.attachNote(op->
getLoc()) <<
"target op";
3622 scf::SCFTilingOptions tilingOptions;
3623 if (tileSizes.empty()) {
3624 tilingOptions.setTileSizeComputationFunction(
3625 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3629 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &
b,
3631 SmallVector<OpFoldResult> sizes;
3632 sizes.reserve(tileSizes.size());
3633 unsigned dynamicIdx = 0;
3635 for (
auto [ofrIdx, ofr] : llvm::enumerate(
getMixedSizes())) {
3636 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3637 if (scalableSizes[ofrIdx]) {
3639 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3641 vector::VectorScaleOp::create(
b, getLoc(),
b.getIndexType());
3643 arith::MulIOp::create(
b, getLoc(), val, vscale).getResult());
3645 sizes.push_back(attr);
3649 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3650 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3652 assert((dynamicSizes.empty() ^ params.empty()) &&
3653 "expected either dynamic sizes or parameters");
3654 if (!params.empty()) {
3655 sizes.push_back(
b.getIndexAttr(params[index]));
3657 sizes.push_back(dynamicSizes[index]->getResult(0));
3664 tilingOptions.setInterchange(getInterchange());
3665 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3666 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3667 if (
failed(maybeTilingResult))
3670 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3672 tiled.append(maybeTilingResult->tiledOps);
3673 for (
const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3674 loops[en2.index()].push_back(en2.value());
3677 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3678 for (
const auto &en : llvm::enumerate(loops))
3679 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3684SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3686 ArrayRef<int64_t> tileSizes = getStaticSizes();
3687 SmallVector<OpFoldResult> results;
3688 results.reserve(tileSizes.size());
3689 unsigned dynamicPos = 0;
3691 for (int64_t size : tileSizes) {
3692 if (size == ShapedType::kDynamic) {
3693 results.push_back(dynamic[dynamicPos++]);
3701void transform::TileUsingForOp::getEffects(
3702 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3713void transform::TileUsingForallOp::build(OpBuilder &builder,
3715 ArrayRef<int64_t> staticTileSizes,
3716 transform::TileSizesSpec,
3718 return build(builder,
result,
3726void transform::TileUsingForallOp::build(OpBuilder &builder,
3728 ArrayRef<OpFoldResult> mixedTileSizes,
3729 transform::TileSizesSpec,
3731 SmallVector<int64_t> staticTileSizes;
3732 SmallVector<Value> dynamicTileSizes;
3738 auto operationType = transform::AnyOpType::get(ctx);
3741 TypeRange{operationType, operationType},
3748 staticTileSizesAttr,
3752void transform::TileUsingForallOp::build(OpBuilder &builder,
3754 ArrayRef<int64_t> staticNumThreads,
3755 transform::NumThreadsSpec,
3759 NumThreadsSpec(), mapping);
3762void transform::TileUsingForallOp::build(OpBuilder &builder,
3764 ArrayRef<OpFoldResult> mixedNumThreads,
3765 transform::NumThreadsSpec,
3767 SmallVector<int64_t> staticNumThreads;
3768 SmallVector<Value> dynamicNumThreads;
3775 auto operationType = transform::AnyOpType::get(ctx);
3778 TypeRange{operationType, operationType},
3784 staticNumThreadsAttr,
3791static SmallVector<OpFoldResult>
3797 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3799 for (
auto [lb,
ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3801 rewriter, loc, normalizedUbExpr, {lb,
ub, step});
3802 normalizedUbs.push_back(normalizedUb);
3804 return normalizedUbs;
3820 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3823 denormalizedIvs.push_back(
3826 return denormalizedIvs;
3837 scf::ForallOp loop) {
3854 auto normalizedForallOp = scf::ForallOp::create(
3855 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3856 loop.getOutputs(), loop.getMapping(),
3859 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3861 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3866 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3867 normalizedForallOp.getRegionIterArgs().end());
3868 Block *origLoopBlock = loop.getBody();
3869 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3871 rewriter.
replaceOp(loop, normalizedForallOp);
3872 return normalizedForallOp;
3880 scf::SCFTilingResult &tilingResult) {
3882 auto tileableOp = dyn_cast<TilingInterface>(
target);
3885 transformOp.emitSilenceableError()
3886 <<
"only TilingInterface ops are supported";
3887 diag.attachNote(
target->getLoc()) <<
"target op";
3891 scf::SCFTilingOptions
options;
3892 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3893 if (!mixedNumThreads.empty()) {
3894 options.setNumThreads(mixedNumThreads);
3896 options.setTileSizes(mixedTileSizes);
3899 options.setMapping(mapping.value().getValue());
3901 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3902 scf::tileUsingSCF(rewriter, tileableOp,
options);
3904 if (failed(maybeTilingResult))
3905 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3907 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3909 tilingResult = *maybeTilingResult;
3911 if (mixedNumThreads.empty()) {
3912 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3915 scf::ForallOp normalizedForallOp =
3917 tilingResult.loops.front() = normalizedForallOp;
3927 auto transformOp = cast<TransformOpInterface>(getOperation());
3936 getPackedNumThreads()
3938 state, transformOp, mixedNumThreads, getPackedNumThreads())
3940 state, transformOp, mixedNumThreads, getMixedNumThreads());
3944 status = getPackedTileSizes()
3946 state, transformOp, mixedTileSizes, getPackedTileSizes())
3948 state, transformOp, mixedTileSizes, getMixedTileSizes());
3953 scf::SCFTilingResult tilingResult;
3955 rewriter, state, transformOp,
target, mixedNumThreads, mixedTileSizes,
3956 getMapping(), tilingResult);
3957 if (!
diag.succeeded())
3959 tileOps.push_back(tilingResult.loops.front());
3960 tiledOps.append(tilingResult.tiledOps);
3963 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3964 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3969void transform::TileUsingForallOp::getEffects(
3970 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3980SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3985SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3990LogicalResult TileUsingForallOp::verify() {
3991 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3992 static_cast<int>(getPackedNumThreads() != Value());
3993 if (numThreadsSpec > 1)
3995 "num_threads and packed_num_threads are mutually exclusive");
3996 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3997 static_cast<int>(getPackedTileSizes() != Value());
3998 if (tileSizesSpec > 1)
4000 "tile_sizes and packed_tile_sizes are mutually exclusive");
4001 if (numThreadsSpec == 0 && tileSizesSpec == 0)
4002 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
4003 "must be specified");
4011void transform::VectorizeChildrenAndApplyPatternsOp::build(
4012 OpBuilder &builder, OperationState &
result, Value
target,
4013 bool foldTypeExtensionsIntoContract,
bool vectorizePadding,
4014 bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
4016 if (foldTypeExtensionsIntoContract) {
4018 VectorizeChildrenAndApplyPatternsOp::
4019 getFoldTypeExtensionsIntoContractAttrName(
result.name),
4022 if (vectorizePadding) {
4024 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4028 if (vectorizeExtract) {
4030 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4034 if (flatten1DDepthwiseConv) {
4036 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4046struct VectorizationPattern :
public RewritePattern {
4047 explicit VectorizationPattern(MLIRContext *context,
4048 bool vectorizeExtract =
false,
4049 bool flattenConv =
false)
4050 : RewritePattern(MatchAnyOpTypeTag(), 1, context),
4051 vectorizeNDExtract(vectorizeExtract),
4052 flatten1DDepthwiseConv(flattenConv) {}
4053 LogicalResult matchAndRewrite(Operation *op,
4054 PatternRewriter &rewriter)
const override {
4057 "Unsupported Op, cannot vectorize");
4058 FailureOr<VectorizationResult> vectorResults =
4060 {}, vectorizeNDExtract,
4061 flatten1DDepthwiseConv);
4062 if (
failed(vectorResults))
4064 rewriter.
replaceOp(op, vectorResults->replacements);
4071 bool vectorizeNDExtract =
false;
4075 bool flatten1DDepthwiseConv =
false;
4079DiagnosedSilenceableFailure
4080transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4081 transform::TransformRewriter &rewriter, Operation *
target,
4082 transform::ApplyToEachResultList &results,
4083 transform::TransformState &state) {
4084 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4085 auto diag = this->
emitOpError(
"requires isolated-from-above targets");
4086 diag.attachNote(
target->getLoc()) <<
"non-isolated target";
4091 RewritePatternSet patterns(ctx);
4092 patterns.
add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4093 getFlatten_1dDepthwiseConv());
4095 if (!getDisableTransferPermutationMapLoweringPatterns())
4098 if (!getDisableMultiReductionToContractPatterns())
4103 patterns.
add<linalg::LinalgCopyVTRForwardingPattern,
4104 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4106 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4107 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4110 patterns.
add<CopyVectorizationPattern>(ctx);
4112 if (getFoldTypeExtensionsIntoContract())
4115 if (getVectorizePadding()) {
4123 TrackingListener listener(state, *
this);
4126 GreedyRewriteConfig().setListener(&listener))))
4127 return emitDefaultDefiniteFailure(
target);
4137DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4138 transform::TransformRewriter &rewriter,
4139 mlir::transform::TransformResults &transformResults,
4140 mlir::transform::TransformState &state) {
4142 if (std::empty(targets))
4144 auto transformOp = cast<TransformOpInterface>(getOperation());
4145 SmallVector<int64_t> vectorSizes;
4147 state, transformOp, getMixedVectorSizes(), vectorSizes);
4152 for (Operation *
target : targets) {
4155 <<
"Unsupported Op, cannot vectorize";
4157 FailureOr<VectorizationResult> vectorResults =
4159 getVectorizeNdExtract().value_or(
false),
4161 getAssumeDynamicDimsMatchVecSizes().value_or(
false),
4162 getCreateNamedContraction().value_or(
false));
4163 if (
failed(vectorResults)) {
4165 <<
"Attempted to vectorize, but failed";
4173void transform::VectorizeOp::getEffects(
4174 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4180SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4185LogicalResult transform::VectorizeOp::verify() {
4186 if (getStaticVectorSizes().size() != getScalableSizes().size())
4187 return emitOpError(
"expected same number of vector sizes (")
4188 << getStaticVectorSizes().size() <<
") and scalable sizes ("
4189 << getScalableSizes().size() <<
")";
4197DiagnosedSilenceableFailure
4198transform::HoistRedundantVectorTransfersOp::applyToOne(
4199 transform::TransformRewriter &rewriter, func::FuncOp
target,
4200 transform::ApplyToEachResultList &results,
4201 transform::TransformState &state) {
4214DiagnosedSilenceableFailure
4215transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4216 transform::TransformRewriter &rewriter, mlir::Operation *
target,
4217 transform::ApplyToEachResultList &results,
4218 transform::TransformState &state) {
4229DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4230 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4231 transform::ApplyToEachResultList &results,
4232 transform::TransformState &state) {
4234 auto maybeTransformed =
4237 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4240 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4243 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4246 .Case([&](linalg::Conv2DNchwFchwOp op) {
4249 .Default([&](Operation *op) {
4252 if (
failed(maybeTransformed))
4253 return emitDefaultSilenceableFailure(
target);
4255 results.
push_back(maybeTransformed->first);
4257 results.
push_back(maybeTransformed->second);
4265DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4266 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4267 transform::ApplyToEachResultList &results,
4268 transform::TransformState &state) {
4272 <<
"only elementwise flattening is supported";
4275 if (
target.getNumLoops() <= 1) {
4282 std::iota(reassociation.begin(), reassociation.end(), 0);
4283 auto maybeFlattened =
4285 if (
failed(maybeFlattened))
4287 <<
"attempted to flatten, but failed";
4288 results.
push_back(maybeFlattened->collapsedOp);
4297DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4298 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4299 transform::ApplyToEachResultList &results,
4300 transform::TransformState &state) {
4302 auto maybeTransformed =
4304 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4307 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4310 .Default([&](Operation *op) {
4313 if (
failed(maybeTransformed))
4314 return emitDefaultSilenceableFailure(
target);
4324DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4325 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4326 transform::ApplyToEachResultList &results,
4327 transform::TransformState &state) {
4329 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4330 auto maybeTransformed =
4332 .Case([&](linalg::MatmulOp op) {
4335 .Case([&](linalg::BatchMatmulOp op) {
4338 .Default(failure());
4339 if (
failed(maybeTransformed))
4349template <
typename OpTy>
4350static DiagnosedSilenceableFailure
4354 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4355 tensor::ParallelInsertSliceOp>() &&
4358 if (
auto copySource =
4359 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4367 if (isa<mlir::ParallelCombiningOpInterface>(
target.getOperation()))
4370 Value extracted = tensor::ExtractSliceOp::create(
4373 Value copied = linalg::CopyOp::create(rewriter,
target.getLoc(),
4374 target.getSource(), extracted)
4386DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4387 transform::TransformRewriter &rewriter, Operation *targetOp,
4388 transform::ApplyToEachResultList &results,
4389 transform::TransformState &state) {
4392 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4393 return doit(rewriter,
target, results, state);
4394 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4395 return doit(rewriter,
target, results, state);
4397 DiagnosedSilenceableFailure
diag =
4398 emitSilenceableError()
4399 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4400 diag.attachNote(targetOp->
getLoc()) <<
"target op";
4408DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4409 transform::TransformRewriter &rewriter, Operation *
target,
4410 transform::ApplyToEachResultList &results,
4411 transform::TransformState &state) {
4413 if (!isa<linalg::CopyOp, tensor::PadOp>(
target)) {
4414 DiagnosedSilenceableFailure
diag =
4415 emitSilenceableError()
4416 <<
"only linalg.copy and tensor.pad target ops are supported";
4417 diag.attachNote(
target->getLoc()) <<
"target op";
4420 assert(
target->getNumResults() == 1 &&
"expected single result");
4421 auto resultShapedType = cast<ShapedType>(
target->getResult(0).getType());
4422 if (!resultShapedType.hasStaticShape()) {
4423 DiagnosedSilenceableFailure
diag =
4424 emitSilenceableError()
4425 <<
"only statically sized ops of rank <= 3 are supported";
4426 diag.attachNote(
target->getLoc()) <<
"target op";
4431 int64_t desiredBitAlignment = getDesiredBitAlignment();
4432 int64_t eltBitwidth =
4433 resultShapedType.getElementType().getIntOrFloatBitWidth();
4434 if (desiredBitAlignment % eltBitwidth != 0) {
4435 desiredBitAlignment = eltBitwidth;
4438 gpu::CopyMappingInfo mapping(
4440 getTotalNumThreads(),
4441 desiredBitAlignment,
4442 resultShapedType.getShape(),
4445 resultShapedType.getElementType().getIntOrFloatBitWidth());
4446 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4447 DiagnosedSilenceableFailure
diag =
4448 emitSilenceableError()
4449 <<
"too few threads to map copy op to threads on the most minor "
4450 "dimension, given alignment and vector size constraints, try "
4451 "smaller tile size of mapping to more threads";
4452 diag.attachNote(
target->getLoc()) <<
"target op";
4458 scf::SCFTilingResult tilingResult;
4465 ArrayRef<OpFoldResult>{},
4466 b.getArrayAttr(mapping.threadMapping),
4468 if (!
diag.succeeded())
4471 results.
push_back(tilingResult.loops.front());
4472 for (
auto *op : tilingResult.tiledOps)
4481DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4482 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4483 transform::ApplyToEachResultList &results,
4484 transform::TransformState &state) {
4486 FailureOr<Operation *> maybeTransformed = failure();
4488 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4493 .Default([&](Operation *op) {
return false; });
4496 return emitSilenceableError()
4497 <<
"this operation is not supported to convert to Winograd Conv2D";
4500 if (
failed(maybeTransformed)) {
4501 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4508DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4509 transform::TransformRewriter &rewriter, Operation *
target,
4510 transform::ApplyToEachResultList &results,
4511 transform::TransformState &state) {
4513 FailureOr<Operation *> maybeTransformed = failure();
4516 .Case([&](linalg::WinogradFilterTransformOp op) {
4520 .Case([&](linalg::WinogradInputTransformOp op) {
4524 .Case([&](linalg::WinogradOutputTransformOp op) {
4531 DiagnosedSilenceableFailure
diag =
4532 emitSilenceableError()
4533 <<
"this operation is not supported to decompose into other operations";
4534 diag.attachNote(
target->getLoc()) <<
"target op";
4538 if (
failed(maybeTransformed)) {
4539 DiagnosedSilenceableFailure
diag =
4540 emitSilenceableError() <<
"decompose Winograd operations failed";
4541 diag.attachNote(
target->getLoc()) <<
"target op";
4549#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4551#define GET_OP_CLASSES
4552#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Base type for affine expression.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation, bool PoisonPaddingOk=false)
Patterns to bubble up or down data layout ops across other operations.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
void populateExtractSliceSinkingPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to sink extract slice across other operations.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice + copy.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Canonicalization patterns relevant to apply after tiling patterns.
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors and memref.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of linalg.pack/unpack ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...