42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/ScopeExit.h"
44#include "llvm/ADT/SmallPtrSet.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
55#define DEBUG_TYPE "linalg-transforms"
62template <
typename PatternTy,
typename... Args>
65 using OpTy =
typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
72 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
80 return cast<LinalgOp>(
result->getOperation());
90 if (
auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 <<
"requires exactly one parameter associated";
103 result.push_back(params[0]);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 <<
"handle must be mapped to exactly one payload op";
113 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
120 transformOp.emitSilenceableError()
121 <<
"payload op must have exactly 1 index result";
141 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
143 for (
auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 <<
"expected the parameter to be associated with an integer "
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 <<
"payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 <<
"has " << op->getNumResults() <<
" results";
162 result.push_back(op->getResult(0));
176 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
180 if (isa<TransformParamTypeInterface>(
181 cast<Value>(paramOrHandle).
getType())) {
183 if (params.size() != 1)
184 return transformOp.emitSilenceableError() <<
"expected a single param";
186 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
190 Value handle = cast<Value>(paramOrHandle);
191 if (!isa<TransformHandleTypeInterface>(handle.getType()))
192 return transformOp.emitSilenceableError() <<
"unexpected value handle";
194 if (!llvm::hasSingleElement(payload))
195 return transformOp.emitSilenceableError()
196 <<
"requires param or handle that is mapped to 1 payload op";
198 Operation *paramOrHandlePayloadOp = *payload.begin();
201 return transformOp.emitSilenceableError()
202 <<
"requires param or handle to be result of op with 1 index "
208 return transformOp.emitSilenceableError()
209 <<
"requires param or handle to be the result of a constant like "
212 reified.push_back(attr.getInt());
221void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
226void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
231void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
236void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
242void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
245 options.rankReductionStrategy =
250void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
255void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
260void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
265void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
270void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
275void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
284void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
288 Operation *producer = opOperand->get().getDefiningOp();
289 Operation *consumer = opOperand->getOwner();
304 SmallVector<Operation *> getNewOps()
const {
305 return SmallVector<Operation *>(newOps.begin(), newOps.end());
309 void notifyOperationInserted(Operation *op,
310 OpBuilder::InsertPoint previous)
override {
311 ForwardingListener::notifyOperationInserted(op, previous);
313 if (previous.
isSet())
317 assert(
inserted.second &&
"expected newly created op");
320 void notifyOperationErased(Operation *op)
override {
321 ForwardingListener::notifyOperationErased(op);
322 op->
walk([&](Operation *op) { newOps.erase(op); });
334 llvm::scope_exit resetListener(
335 [&]() { rewriter.
setListener(previousListener); });
336 NewOpsListener newOpsListener(previousListener);
340 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
341 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
342 MaterializeInDestination;
343 }
else if (getMemcpyOp() ==
"memref.copy") {
346 }
else if (getMemcpyOp() ==
"linalg.copy") {
350 llvm_unreachable(
"invalid memcpy op");
352 if (getAllocOp() ==
"memref.alloc") {
355 }
else if (getAllocOp() ==
"memref.alloca") {
359 llvm_unreachable(
"invalid alloc op");
361 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
362 options.emitDealloc = getEmitDealloc();
366 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
373 <<
"failed to bufferize operation";
374 diag.attachNote(op->
getLoc()) <<
"target payload op";
377 allocatedBuffers.push_back(buffer);
381 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
382 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
386void transform::BufferizeToAllocationOp::getEffects(
388 if (getBufferizeDestinationOnly()) {
399LogicalResult transform::BufferizeToAllocationOp::verify() {
400 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
401 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
403 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
416 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.
getOwner());
423 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
431 if (!isa<TensorType, FloatType, IntegerType>(value.
getType()))
433 return llvm::any_of(value.
getUses(),
443 auto type = dyn_cast<RankedTensorType>(
tensor.getType());
445 return emitSilenceableError() <<
"non-tensor type: " <<
tensor;
459 for (
auto [pos, dim] : llvm::enumerate(type.getShape())) {
460 if (!ShapedType::isDynamic(dim))
465 tensor::DimOp::create(rewriter,
tensor.getLoc(),
tensor, cst);
466 preservedOps.insert(dimOp);
467 dynamicDims.push_back(dimOp);
469 auto allocation = bufferization::AllocTensorOp::create(
470 rewriter,
tensor.getLoc(), type, dynamicDims);
472 if (getMemorySpaceAttr())
473 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
474 Value allocated = allocation;
478 if (needsMaterialization) {
479 auto copy = bufferization::MaterializeInDestinationOp::create(
481 preservedOps.insert(
copy);
482 promoted.push_back(
copy.getResult());
484 promoted.push_back(allocated);
488 results.
setValues(cast<OpResult>(getPromoted()), promoted);
492void transform::PromoteTensorOp::getEffects(
508 FailureOr<linalg::LinalgOp> res =
510 if (succeeded(res)) {
514 return emitDefaultSilenceableFailure(
target);
528 auto decomposableOp = dyn_cast<AggregatedOpInterface>(
target);
529 if (!decomposableOp) {
531 "payload is not a decomposable op"));
532 return emitDefaultSilenceableFailure(
target);
535 FailureOr<SmallVector<Value>> maybeNewResults =
536 decomposableOp.decomposeOperation(rewriter);
537 if (
failed(maybeNewResults))
538 return emitDefaultSilenceableFailure(
target);
540 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
541 for (
Value val : *maybeNewResults) {
542 Operation *definition = val.getDefiningOp();
553void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
560transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
564 options.allowReturnAllocsFromLoops =
true;
570 <<
"failed to analyze op";
572 rewriter,
target, state)))
574 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
587 bool applyCleanup,
bool useForall) {
589 builder,
result, loopTypes,
595 applyCleanup, useForall);
601 bool applyCleanup,
bool useForall) {
609 applyCleanup, useForall);
616 bool applyCleanup,
bool useForall) {
620 build(builder,
result, loopTypes,
target, mixedTileSizes,
621 mixedTileInterchange, applyCleanup, useForall);
628 bool applyCleanup,
bool useForall) {
635 staticTileInterchange);
640 auto staticTileInterchangeAttr =
642 unsigned numExpectedLoops =
643 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
645 resultTypes.reserve(numExpectedLoops);
646 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
647 "expected one loop type or as many as loops");
648 if (loopTypes.size() == 1)
649 resultTypes.append(numExpectedLoops, loopTypes[0]);
651 llvm::append_range(resultTypes, loopTypes);
656 dynamicTileInterchange,
659 staticTileInterchangeAttr,
666template <
typename Range>
671 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
675 size_t numTargets = llvm::range_size(payloadOps);
678 auto tilingInterfaceOp = dyn_cast<TilingInterface>(
target);
679 if (!tilingInterfaceOp)
680 return transformOp->
emitError(
"only TilingInterface ops are supported");
683 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
684 applyFn(tilingInterfaceOp);
685 if (failed(tiledResults))
690 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
691 for (
Operation *toReplace : opsToReplace) {
692 for (
OpResult res : toReplace->getResults())
693 if (
auto replacement = tiledResults->replacements.lookup(res))
695 if (toReplace->use_empty()) {
701 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
702 assert(tiledResults->loops.size() == numLoops &&
703 "Mismatched number of loops, tile and fuse transform should have "
705 for (
unsigned int i = 0; i < numLoops; ++i)
706 loopOps[i].
push_back(tiledResults->loops[i]);
709 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
718 for (
unsigned int idx = 0; idx < numTargets; ++idx)
719 for (
unsigned int i = 0; i < numLoops; ++i)
720 flattenedLoopOps.push_back(loopOps[i][idx]);
721 transformResults.
set(transformOp->
getOpResult(1), flattenedLoopOps);
723 for (
unsigned int i = 0; i < numLoops; ++i)
724 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
734 auto transformOp = cast<TransformOpInterface>(getOperation());
740 state, transformOp, mixedTileSizes, getPackedTileSizes())
742 state, transformOp, mixedTileSizes, getMixedTileSizes());
747 state, transformOp, getMixedTileInterchange(), tileInterchange);
751 scf::SCFTilingOptions tilingOptions;
752 tilingOptions.interchangeVector = tileInterchange;
753 bool useForall = getUseForall();
754 tilingOptions.setLoopType(useForall
755 ? scf::SCFTilingOptions::LoopType::ForallOp
756 : scf::SCFTilingOptions::LoopType::ForOp);
757 tilingOptions = tilingOptions.setTileSizes(mixedTileSizes);
758 scf::SCFTileAndFuseOptions tileAndFuseOptions;
759 tileAndFuseOptions.tilingOptions = tilingOptions;
761 if (getApplyCleanup()) {
764 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
767 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
774 numLoops = llvm::count_if(mixedTileSizes, [](
OpFoldResult ofr) {
775 auto attr = dyn_cast<Attribute>(ofr);
778 return cast<IntegerAttr>(attr).getInt() != 0;
782 rewriter, getOperation(), state.
getPayloadOps(getTarget()), numLoops,
783 transformResults, getPackedTileSizes() !=
nullptr,
784 [&](TilingInterface tilingInterfaceOp)
785 -> FailureOr<scf::SCFTileAndFuseResult> {
786 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
793LogicalResult transform::FuseOp::verify() {
794 bool hasPackedTiles = getPackedTileSizes() !=
nullptr;
795 if (!getMixedTileSizes().empty() && hasPackedTiles)
797 "tile_sizes and packed_tile_sizes are mutually exclusive");
799 auto iterspace_rank = getStaticTileSizes().size();
801 if (permutation.size() > iterspace_rank)
803 <<
"interchange length exceeds iteration space dimensions ("
804 << iterspace_rank <<
"), found " << getTileInterchange();
806 for (
int64_t v : permutation) {
807 if (!ShapedType::isDynamic(v)) {
808 if (v < 0 || v >=
static_cast<int64_t>(iterspace_rank))
809 return emitOpError() <<
"expects interchange values to be in range [0, "
810 << iterspace_rank <<
"), found: " << v;
812 return emitOpError() <<
"found duplicate interchange value: " << v;
818 size_t numExpectedLoops = getUseForall() || hasPackedTiles
820 : sizes.size() - llvm::count(sizes, 0);
821 if (numExpectedLoops != getNumResults() - 1)
822 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
832 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
836void transform::FuseOp::getEffects(
850void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
853 Value containingOp) {
854 result.addOperands({producerOp, containingOp});
855 auto resultType = transform::AnyOpType::get(builder.
getContext());
856 result.addTypes({resultType, resultType});
872 (domInfo.
dominates(containingOp, user))) {
873 dominatedUsers.insert(user);
876 if (dominatedUsers.empty())
880 auto forallOp = cast<scf::ForallOp>(containingOp);
886 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
891 newOuts.push_back(outputs[resultNumber]);
894 auto newforallOp = scf::ForallOp::create(
895 rewriter, loc, forallOp.getMixedLowerBound(),
896 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
897 forallOp.getMapping());
899 newforallOp.getRegion().takeBody(forallOp.getRegion());
904 newforallOp.getBody()->addArgument(newOuts.back().getType(),
905 newOuts.back().getLoc());
906 auto bbArgs = newforallOp.getBody()->getArguments();
909 Operation *op = use.getOwner();
910 return newforallOp->isProperAncestor(op);
914 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
916 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; });
917 Operation *firstYieldOp = yieldingOps.front();
920 Value dst = newforallOp.getRegionIterArgs().back();
922 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->
getLoc(), src,
923 dst, offsets, sizes, strides);
925 for (
auto result : llvm::enumerate(forallOp.getResults())) {
927 newforallOp->getResult(
result.index()));
930 newforallOp->getResults().back(),
932 Operation *user = use.getOwner();
933 return dominatedUsers.contains(user);
947 destWorklist.push_back(dst);
949 while (!destWorklist.empty()) {
950 Value currentDst = destWorklist.pop_back_val();
954 if (src == currentDst)
959 auto bbArg = dyn_cast<BlockArgument>(currentDst);
963 Block *parentBlock = bbArg.getOwner();
964 assert(parentBlock &&
"unlinked block argument");
967 assert(parentOp &&
"expected block argument with parent operation");
970 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
974 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
976 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
977 Value loopBlockArgument =
979 destWorklist.push_back(loopBlockArgument);
992static std::tuple<SmallVector<Operation *>,
Operation *>
995 LDBG() <<
"Try to fuse a direct extract use";
996 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
997 if (!tileableProducer) {
999 <<
"producer is not a TileableInterface: " << *producerOp;
1006 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
1007 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1008 return sliceOp && containingOp->isProperAncestor(sliceOp);
1012 if (it == tileableProducer->getUsers().end()) {
1013 diag.attachNote(tileableProducer->getLoc())
1014 <<
"could not find fusion opportunity for: " << *tileableProducer;
1017 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
1030 if (LoopLikeOpInterface containerLoop =
1031 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
1037 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(
clone);
1041 for (
OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1042 Value producerOperand =
1043 clone->getOperand(initOperandPtr.getOperandNumber());
1045 containerLoop.getRegionIterArgs()) {
1046 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1047 Value consumerOperand =
1051 initOperandPtr.set(containerIterArg);
1057 tileableProducer = dyn_cast<TilingInterface>(
clone);
1062 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1063 LDBG() <<
"resultNumber: " << resultNumber;
1068 FailureOr<TilingResult> tileAndFuseResult =
1069 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1072 if (failed(tileAndFuseResult)) {
1073 diag.attachNote(tileableProducer->getLoc())
1074 <<
"failed to tile producer op: " << *tileableProducer;
1079 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
1080 LDBG() <<
"tiledProducer: " << *tiledOp;
1085 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1086 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1087 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1088 if (failed(maybeRankReduced)) {
1090 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
1091 << tileAndFuseResult->tiledValues[0]
1092 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
1095 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1099 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
1100 resultNumber, offsets, sizes);
1103 if (isa<LoopLikeOpInterface>(containingOp))
1104 rewriter.
eraseOp(tileableProducer);
1106 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1119 LDBG() <<
"Try to fuse an extract use through block argument";
1121 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1122 if (!tileableProducer) {
1124 <<
"producer is not a TileableInterface: " << *producerOp;
1129 scf::ForallOp forallOp;
1130 auto itProducerUses =
1131 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
1132 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1136 if (!forallOp || forallOp != containingOp) {
1137 diag.attachNote(tileableProducer->getLoc())
1138 <<
"could not find a use by the containing op: " << *tileableProducer;
1153 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1154 return sliceOp && containingOp->isProperAncestor(sliceOp);
1158 if (itBBArgUsers == bbArg.
getUsers().end()) {
1160 <<
"could not find fusion opportunity for bbArg: " << bbArg;
1163 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1171 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
1172 LDBG() <<
"resultNumber: " << resultNumber;
1177 rewriter, tileableProducer->getLoc(), tileableProducer,
1178 destinationTensors))) {
1179 diag.attachNote(tileableProducer->getLoc())
1180 <<
"failed to get destination tensors for: " << *tileableProducer;
1185 bvm.
map(destinationTensors[resultNumber], bbArg);
1186 auto tileableProducerClone =
1187 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
1188 llvm::scope_exit scopeGuard(
1189 [&]() { rewriter.
eraseOp(tileableProducerClone); });
1192 FailureOr<TilingResult> tileAndFuseResult =
1193 tileableProducerClone.generateResultTileValue(
1194 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1195 sliceOpToTile.getMixedSizes());
1196 if (failed(tileAndFuseResult)) {
1197 diag.attachNote(tileableProducer->getLoc())
1198 <<
"failed to tile producer op: " << *tileableProducer;
1203 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1204 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1205 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1206 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
1207 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1212 destinationTensors.front());
1215 return tileAndFuseResult->tiledOps;
1221 LDBG() <<
"Try to fuse an use by cloning";
1228 uses.push_back(&use);
1233 if (containingOp == use.getOwner()) {
1235 <<
"producer op use by containing op cannot be fused by cloning";
1243 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1252 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1253 "Parallel insert slice is not a valid clone destination");
1254 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1255 LDBG() <<
"resultNumber: " << resultNumber;
1259 fusedOp = rewriter.
clone(*producerOp);
1261 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1266bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1277 auto containingOps = state.
getPayloadOps(getContainingOp());
1278 if (!llvm::hasSingleElement(containingOps)) {
1280 <<
"requires exactly one containing_op handle (got "
1281 << llvm::range_size(containingOps) <<
")";
1283 Operation *containingOp = *containingOps.begin();
1286 if (std::empty(producerOps)) {
1288 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1295 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1296 for (
const auto &it :
enumerate(remainingProducers)) {
1299 int64_t numUsesInContainingOp =
1301 return containingOp->isAncestor(op);
1306 if (numUsesInContainingOp > 0) {
1307 if (numUsesInContainingOp == 1)
1308 remainingProducers.erase(remainingProducers.begin() + it.index());
1315 while (!remainingProducers.empty()) {
1316 auto nextProducer = getNextProducer();
1317 if (
failed(nextProducer)) {
1319 <<
"could not find next producer to fuse into container";
1320 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1328 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1335 auto [tiledOps, newContainingOp] =
1337 if (!tiledOps.empty()) {
1338 LDBG() <<
"\nFused a direct extract use\n" << *containingOp;
1339 fusedOps.append(tiledOps);
1340 if (newContainingOp) {
1348 LogicalResult replacementStatus =
1351 (
void)replacementStatus;
1352 assert(succeeded(replacementStatus) &&
1353 "unable to update transform state mapping");
1354 rewriter.
eraseOp(containingOp);
1355 containingOp = newContainingOp;
1362 rewriter,
diag, producerOp, containingOp);
1363 if (!tiledContainingOpOperand.empty()) {
1364 LDBG() <<
"\nFused an extract use through block argument\n"
1366 fusedOps.append(tiledContainingOpOperand);
1373 LDBG() <<
"\nFused an use by cloning\n" << *containingOp;
1374 fusedOps.push_back(cloned);
1380 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1381 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1385void transform::FuseIntoContainingOp::getEffects(
1403 if (isa<GenericOp>(
target)) {
1409 if (succeeded(generic)) {
1410 results.
push_back(generic->getOperation());
1413 return emitDefaultSilenceableFailure(
target);
1426 if (!isa<GenericOp>(
target)) {
1433 FailureOr<LinalgOp> named =
1435 if (succeeded(named)) {
1436 results.
push_back(named->getOperation());
1439 return emitDefaultSilenceableFailure(
target);
1453 if (interchangeVector.empty()) {
1458 unsigned numLoops = cast<LinalgOp>(
target.getOperation()).getNumLoops();
1459 if (interchangeVector.size() != numLoops) {
1460 return emitSilenceableError()
1461 << getIteratorInterchangeAttrName() <<
" has length ("
1462 << interchangeVector.size()
1463 <<
") different from the number of loops in the target operation ("
1474LogicalResult transform::InterchangeOp::verify() {
1476 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1477 if (!std::is_permutation(sequence.begin(), sequence.end(),
1478 permutation.begin(), permutation.end())) {
1480 <<
"expects iterator_interchange to be a permutation, found "
1481 << getIteratorInterchange();
1496 if (!isa<linalg::CopyOp>(targetOp)) {
1498 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1499 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1503 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1504 if (!copyOp.hasPureBufferSemantics()) {
1506 emitSilenceableError()
1507 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1508 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1514 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1515 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1516 Value input = inputs.front();
1517 Value output = outputs.front();
1522 if (!isa<ShapedType>(input.
getType())) {
1524 emitSilenceableError()
1525 <<
"cannot transform a linalg.copy which input has no shape";
1526 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1531 assert(isa<ShapedType>(output.
getType()));
1533 if (cast<ShapedType>(input.
getType()).getElementType() !=
1534 cast<ShapedType>(output.
getType()).getElementType()) {
1536 emitSilenceableError()
1537 <<
"cannot transform a linalg.copy with different source and "
1538 "destination element types ";
1539 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1560 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1561 FailureOr<LowerPackResult> res =
1565 <<
"cannot lower to pad + expand + transpose";
1568 transformResults.
push_back(res->expandShapeOp);
1569 transformResults.
push_back(res->transposeOp);
1582 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1583 FailureOr<LowerUnPackOpResult> res =
1587 emitSilenceableError()
1588 <<
"cannot lower to transpose + collapse + extract";
1589 diag.attachNote(
target->getLoc()) <<
"target payload op";
1592 transformResults.
push_back(res->emptyOp);
1593 transformResults.
push_back(res->transposeOp);
1594 transformResults.
push_back(res->collapseShapeOp);
1595 transformResults.
push_back(res->extractSliceOp);
1596 transformResults.
push_back(res->copyOp);
1607 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1616 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1618 result.addTypes(resultTypes);
1626 if (getOps().has_value())
1627 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1630 if (!llvm::hasSingleElement(payloadOps)) {
1635 bool incorrectNumOperandTypes =
false;
1642 if (getInterface().has_value()) {
1643 auto iface = getInterface().value();
1644 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1647 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1648 !isa<TilingInterface>(op))
1650 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1651 !isa<LoopLikeOpInterface>(op))
1656 if (getOpAttrs().has_value()) {
1657 DictionaryAttr opAttrs = getOpAttrs().value();
1659 if (attr.getName() == getInterfaceAttrName() ||
1660 attr.getName() == getOpsAttrName())
1662 if (!op->
hasAttr(attr.getName()))
1664 if (op->
getAttr(attr.getName()) != attr.getValue())
1669 if (getFilterResultType().has_value()) {
1670 Type t = getFilterResultType().value();
1675 if (getFilterOperandTypes().has_value()) {
1676 mlir::ArrayAttr types = getFilterOperandTypes().value();
1679 if (types.size() == 1) {
1682 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1683 Type t = cast<::mlir::Type>(typeattr.getValue());
1685 [&](
Type operandType) { return operandType == t; }))
1690 if (types.size() != operandTypes.size()) {
1691 incorrectNumOperandTypes =
true;
1695 for (
auto [attr, operandType] :
1696 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1697 auto typeattr = cast<mlir::TypeAttr>(attr);
1698 Type type = cast<::mlir::Type>(typeattr.getValue());
1700 if (type != operandType)
1711 (*payloadOps.begin())->walk(matchFun);
1712 if (incorrectNumOperandTypes)
1714 "type, then it must contain as much types as "
1715 "the number of operands in the target ops");
1716 results.
set(cast<OpResult>(getResult()), res);
1731 Type &targetType,
Type &lowSizeType,
1733 Type &splitPointType) {
1734 FunctionType funcType;
1736 if (failed(parser.
parseType<FunctionType>(funcType)))
1739 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1740 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1741 "argument and one result";
1743 targetType = funcType.getInput(0);
1744 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1752 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1753 if (
target.hasDynamicShape()) {
1754 auto diag = emitSilenceableError()
1755 <<
"cannot compute parametric tile sizes for dynamically "
1756 "shaped payload op";
1757 diag.attachNote(
target->getLoc()) <<
"payload op";
1762 target, getDimension(), getTargetSize(), getDivisor());
1764 return emitSilenceableError()
1765 <<
"failed to compute multi-size tiling sizes";
1769 results.
assign(llvm::map_range(
1771 spec->lowTileSize * spec->lowTripCount}),
1772 [&builder,
this](
int64_t value) {
1784 builder,
target, getDimension(), targetSize, divisor);
1786 return emitSilenceableError() <<
"could not generate tile size computation";
1793 {spec->lowTileSize, spec->lowTripCount});
1794 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1795 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1796 assert(lowTileSize && highTileSize && splitPoint &&
1797 "tile sizes are not produced by operations");
1805void transform::MultiTileSizesOp::getEffects(
1809 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1815LogicalResult transform::MultiTileSizesOp::verify() {
1818 return emitOpError() <<
"expects all results type to be the same";
1837 Type linalgOpHType = transform::OperationType::get(
1838 builder.
getContext(), GenericOp::getOperationName());
1857 if (std::empty(targetOps)) {
1858 transformResults.
set(cast<OpResult>(getPackedOp()),
1863 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1864 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1865 return emitSilenceableError()
1866 <<
"requires target to map to exactly 1 LinalgOp (got "
1867 << llvm::range_size(targetOps) <<
")";
1870 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1871 return emitSilenceableError()
1872 <<
"requires number of packed sizes match the number of loops ("
1873 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1880 state, *
this, packedSizes, getMixedPackedSizes());
1883 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1887 transformResults.
set(cast<OpResult>(getPackedOp()),
1888 {maybeResult->packedLinalgOp.getOperation()});
1892void transform::PackOp::getEffects(
// Verifier for PackGreedilyOp.
// Visible checks:
//  - an op error naming the attribute returned by
//    getMatmulInnerDimsOrderAttrName() as not being a valid permutation;
//  - when getMatmulPaddedSizesNextMultipleOf() is non-empty, a lockstep walk
//    over it and the mixed matmul packed sizes that rejects pairs violating
//    the rule spelled out in the diagnostic (at most one of the two nonzero).
// NOTE(review): the permutation test, the definition of maybeStaticPackedSize
// and the first half of the pair condition are missing from this listing.
1904LogicalResult transform::PackGreedilyOp::verify() {
1906 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1907 <<
" is not a valid permutation";
1910 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
// zip_equal is used, so both ranges are expected to have the same length.
1911 for (
auto [s, nmo] :
1912 llvm::zip_equal(getMixedMatmulPackedSizes(),
1913 getMatmulPaddedSizesNextMultipleOf())) {
// A packed size that is not statically known is treated like a nonzero one.
1916 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1917 return emitOpError() <<
"at most one of the packed_size and the "
1918 "padded_sizes_next_multiple_of can be nonzero "
1919 "for the matmul strategy";
1932 auto linalgOp = dyn_cast<LinalgOp>(op);
1943 getMixedMatmulPackedSizes(),
1945 getMatmulPaddedSizesNextMultipleOf(),
1946 getMatmulInnerDimsOrder());
1947 if (succeeded(packResult)) {
1948 results.push_back(packResult->packedLinalgOp);
1951 results.push_back(linalgOp);
1953 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1959 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1963void transform::PackGreedilyOp::getEffects(
// Verifier for PackTransposeOp.
// The visible diagnostics cover two invalid-permutation errors and the
// requirement that the inner and outer permutations are not both empty.
// NOTE(review): the two permutation tests and the attribute names streamed
// ahead of the first two messages are missing from this listing.
1975LogicalResult transform::PackTransposeOp::verify() {
1978 <<
" is not a valid permutation";
1982 <<
" is not a valid permutation";
// Reject the op when neither permutation is given.
1984 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1985 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1986 <<
" or " << getOuterPermAttrName()
1987 <<
" must be specified";
// Selects which permutation of a pack/unpack op is meant: the outer one
// (value 0) or the inner one (value 1). Used below as the last parameter of
// isValidPackingPermutation, where Outer is the default.
1993enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
// Checks whether a permutation is acceptable for a pack or unpack op.
// RelayoutOpTy is statically restricted to linalg::PackOp or
// linalg::UnPackOp. outerOrInnerPerm selects which permutation is being
// validated and defaults to Outer.
// NOTE(review): the op and permutation parameters, the values returned by the
// early exit and by the Inner branch, and the second half of each size
// comparison are missing from this listing.
2003template <
typename RelayoutOpTy>
2004static bool isValidPackingPermutation(
2006 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
2008 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
2009 "applies to only pack or unpack operations");
// Null op or empty permutation is handled up front.
2010 if (!op || permutation.empty())
// Number of inner (tiled) dimension positions of the op.
2012 size_t innerRank = op.getInnerDimsPos().size();
2013 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
// For the remaining case the expected size depends on the op kind: source
// rank for a pack op, destination rank otherwise.
2017 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
2018 return permutation.size() == op.getSourceRank() &&
2021 return permutation.size() == op.getDestRank() &&
2029 auto packOrUnpackOps = state.
getPayloadOps(getTargetPackOrUnPackOp());
2032 if (std::empty(packOrUnpackOps)) {
2033 transformResults.
set(cast<OpResult>(getPackedOp()), {});
2034 transformResults.
set(cast<OpResult>(getPackOp()), {});
2035 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
2041 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2042 !llvm::hasSingleElement(linalgOps)) {
2043 return emitSilenceableError()
2044 <<
"requires target to map to exactly 1 "
2045 "packing op and 1 packed op ("
2046 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
2047 << llvm::range_size(linalgOps) <<
")";
2051 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2052 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2053 if ((!packOp && !unPackOp)) {
2054 return emitSilenceableError() <<
"requires target to map to a "
2055 "linalg.pack or linalg.unpack";
2057 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2058 if (!linalgOpTarget)
2059 return emitSilenceableError() <<
"requires a LinalgOp target";
2063 if (packOp && packOp.getResult().hasOneUse())
2064 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2066 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2067 if (linalgOp != linalgOpTarget) {
2069 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
2070 : StringLiteral{
"not produced by the LinalgOp target"};
2071 return emitSilenceableError() << errorMsg;
2077 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
2078 OpOperand *packUse = linalgOp.getDpsInitOperand(
2079 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2081 if (!packOp || !packOp.getResult().hasOneUse())
2082 return emitSilenceableError() <<
"could not find matching pack op";
2086 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2088 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2089 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2090 ? StringLiteral{
"invalid outer_perm"}
2091 : StringLiteral{
"invalid inner_perm"};
2092 if (!isValidPackingPermutation(packOp, perm, permType) ||
2093 !isValidPackingPermutation(unPackOp, perm, permType)) {
2095 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2096 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
2102 assert(packOp && linalgOp &&
"unexpected null op");
2106 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2108 assert(succeeded(res) &&
"unexpected packTranspose failure");
2111 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2112 transformResults.
set(cast<OpResult>(getPackedOp()),
2113 {res->transposedLinalgOp});
2115 transformResults.
set(cast<OpResult>(getUnPackOp()),
2116 {res->transposedUnPackOp});
2118 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
2133 StringRef copyBackOp,
2134 bool usePrescribedTensorShapes) {
2135 auto resultType = transform::AnyOpType::get(
b.getContext());
2141 b.getI64ArrayAttr(paddingDimensions),
2144 (padToMultipleOf.empty()
2146 :
b.getDenseI64ArrayAttr(padToMultipleOf)),
2147 b.getI64ArrayAttr(nofoldFlags),
2148 b.getArrayAttr(transposePaddings),
2149 b.getStringAttr(copyBackOp),
2151 usePrescribedTensorShapes ?
b.getUnitAttr() :
nullptr);
2159 StringRef copyBackOp,
2160 bool usePrescribedTensorShapes) {
2161 auto resultType = transform::AnyOpType::get(
b.getContext());
2165 staticPadToMultipleOf);
2171 b.getI64ArrayAttr(paddingDimensions),
2172 dynamicPadToMultipleOf,
2173 staticPadToMultipleOf,
2174 b.getI64ArrayAttr(nofoldFlags),
2175 b.getArrayAttr(transposePaddings),
2177 usePrescribedTensorShapes);
2180void PadOp::getEffects(
// Returns the pad-to-multiple-of values as a single list of OpFoldResult by
// passing the static values and the dynamic operands to getMixedValues.
// NOTE(review): the declaration of b is on a line missing from this listing;
// presumably a Builder created from the op context.
2188SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2190 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(),
b);
// Applies the pad transform to the payload ops.
// Visible flow for each target:
//  1. require a LinalgOp, otherwise emit a silenceable error with a note;
//  2. convert the nofold flags to booleans;
//  3. collect one padding value per (padding value, operand type) pair,
//     checking or parsing it against the expected element type;
//  4. collect transpose paddings and the pad-to-multiple-of values;
//  5. translate the copy-back op name into a LinalgPaddingOptions enum value;
//  6. optionally set explicit sizes to pad to for non-static dimensions;
//  7. run the padding rewrite, replace the target and record the results.
// The three result handles are set from paddedOps, padOps and copyBackOps.
// NOTE(review): many lines (loop headers, the options declaration, the
// rewrite call, returns and closing braces) are missing from this listing.
2193DiagnosedSilenceableFailure
2194transform::PadOp::apply(transform::TransformRewriter &rewriter,
2195 transform::TransformResults &results,
2196 transform::TransformState &state) {
2197 auto transformOp = cast<TransformOpInterface>(getOperation());
2198 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
// Only LinalgOp targets are supported.
2201 auto linalgTarget = dyn_cast<LinalgOp>(
target);
2202 if (!linalgTarget) {
2203 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
2204 diag.attachNote(
target->getLoc()) <<
"target op";
// The nofold flags arrive as integers and are converted to bool here.
2209 SmallVector<bool> nofoldFlags;
2210 for (int64_t packPadding :
2212 nofoldFlags.push_back(
static_cast<bool>(packPadding));
// Build the list of padding values, pairing each attribute with the type of
// the corresponding operand of the target.
2215 SmallVector<Attribute> paddingValues;
2216 for (
auto const &[untypedAttr, elementOrTensorType] :
2217 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2220 paddingValues.push_back(untypedAttr);
2223 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2225 emitOpError(
"expects padding values to be typed attributes or poison");
// String attributes go through a parsing step; the parsed attribute must
// exist and carry the expected element type.
2230 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2234 if (!parsedAttr || parsedAttr.getType() != elementType) {
2236 << elementType <<
", got " << untypedAttr;
2237 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2240 paddingValues.push_back(parsedAttr);
// Other typed attributes must match the element type directly.
2244 if (attr.getType() != elementType) {
2246 << elementType <<
", got " << attr;
2247 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2250 paddingValues.push_back(attr);
2254 SmallVector<SmallVector<int64_t>> transposePaddings;
2255 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2257 cast<ArrayAttr>(transposeVector)));
2264 SmallVector<int64_t> padToMultipleOf;
2266 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
// With no multiples given, a vector of ones sized like the padding
// dimensions is built.
2269 if (padToMultipleOf.empty())
2271 SmallVector<int64_t>(
options.paddingDimensions.size(), 1);
2273 options.padToMultipleOf = padToMultipleOf;
2274 options.paddingValues = paddingValues;
2275 options.nofoldFlags = nofoldFlags;
// Map the copy-back op name onto the options enum; any other name reaches
// llvm_unreachable.
2276 if (getCopyBackOp() ==
2277 bufferization::MaterializeInDestinationOp::getOperationName()) {
2278 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2279 BufferizationMaterializeInDestination;
2280 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2281 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2282 }
else if (getCopyBackOp() == kCopyOpNone) {
2283 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2285 llvm_unreachable(
"unsupported copy_back op");
2288 bool irChanged =
false;
// Only taken when prescribed tensor shapes are requested and the target has
// pure tensor semantics; statically sized dimensions are tested for in the
// inner loop.
2289 if (getUsePrescribedTensorShapes() &&
2290 linalgTarget.hasPureTensorSemantics()) {
2291 OpBuilder::InsertionGuard g(rewriter);
2293 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2294 for (
auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2295 if (ShapedType::isStatic(dim))
2297 options.setSizeToPadTo(operand.getOperandNumber(), i,
2299 operand.get().getLoc(),
2306 SmallVector<Value> replacements;
2307 SmallVector<tensor::PadOp> newPadOps;
2309 replacements, newPadOps))) {
2315 auto diag = emitSilenceableError() <<
"failed to pad op";
2316 diag.attachNote(
target->getLoc()) <<
"target op";
2325 rewriter.
replaceOp(linalgTarget, replacements);
2326 paddedOps.push_back(paddedOp);
2327 padOps.append(newPadOps.begin(), newPadOps.end());
// When a copy-back op is requested, the defining op of each replacement
// value is recorded once (duplicates are filtered with is_contained).
2328 if (
options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2329 for (Value v : replacements) {
2330 Operation *copyBackOp = v.getDefiningOp();
2331 if (!llvm::is_contained(copyBackOps, copyBackOp))
2332 copyBackOps.push_back(copyBackOp);
2337 results.
set(cast<OpResult>(getPadded()), paddedOps);
2338 results.
set(cast<OpResult>(getPad()), padOps);
2339 results.
set(cast<OpResult>(getCopy()), copyBackOps);
// Verifier for PadOp.
// Visible checks, in order:
//  - every nofold flag is 0 or 1;
//  - no padding dimension is negative;
//  - if pad-to-multiple-of values are given, there are as many of them as
//    padding dimensions;
//  - each transpose padding is a permutation of 0..size-1;
//  - the copy-back op name is one of the three supported names.
// NOTE(review): initializers, some diagnostics and the final returns are
// missing from this listing.
2343LogicalResult transform::PadOp::verify() {
2344 SmallVector<int64_t> nofoldFlags =
2346 if (any_of(nofoldFlags, [](int64_t packPadding) {
2347 return packPadding != 0 && packPadding != 1;
2350 <<
"expects nofold_flags to contain booleans (0/1), found "
2351 << getNofoldFlags();
2354 SmallVector<int64_t> paddingDimensions =
// NOTE(review): the predicate rejects only negative values, so zero is
// accepted even though the message below says positive.
2356 if (any_of(paddingDimensions,
2357 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2358 return emitOpError() <<
"expects padding_dimensions to contain positive "
2360 << getPaddingDimensions();
2362 if (!getMixedPadToMultipleOf().empty()) {
2363 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2364 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2367 ArrayAttr transposes = getTransposePaddings();
2368 for (Attribute attr : transposes) {
// Compare against the identity sequence of the same length.
2370 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2371 if (!std::is_permutation(sequence.begin(), sequence.end(),
2372 transpose.begin(), transpose.end())) {
2374 <<
"expects transpose_paddings to be a permutation, found "
// The copy-back op must be one of: bufferization materialize-in-destination,
// linalg copy, or the kCopyOpNone marker.
2378 if (getCopyBackOp() !=
2379 bufferization::MaterializeInDestinationOp::getOperationName() &&
2380 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2381 getCopyBackOp() != kCopyOpNone)
// Builder overload taking static padding sizes. All results use
// transform::AnyOpType. The padding sizes are passed on as a dense i64 array
// attribute, and padToMultipleOf becomes a unit attribute when true and a
// null attribute otherwise.
// NOTE(review): the remaining parameters and the delegated build call are
// missing from this listing.
2390void transform::PadTilingInterfaceOp::build(OpBuilder &
b,
2393 ArrayRef<int64_t> paddingSizes,
2394 bool padToMultipleOf) {
2395 auto resultType = transform::AnyOpType::get(
b.getContext());
2404 :
b.getDenseI64ArrayAttr(paddingSizes)),
2406 padToMultipleOf ?
b.getUnitAttr() :
nullptr);
// Builder overload taking mixed (static or dynamic) padding sizes. All
// results use transform::AnyOpType. Two local vectors are declared for the
// static and the dynamic sizes.
// NOTE(review): the call that fills the two vectors and the delegated build
// call are missing from this listing.
2409void transform::PadTilingInterfaceOp::build(
2411 ArrayRef<OpFoldResult> mixedPaddingSizes,
bool padToMultipleOf) {
2412 auto resultType = transform::AnyOpType::get(
b.getContext());
2413 SmallVector<int64_t> staticPaddingSizes;
2414 SmallVector<Value> dynamicPaddingSizes;
2416 staticPaddingSizes);
2422 dynamicPaddingSizes,
2427void transform::PadTilingInterfaceOp::getEffects(
2428 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2435SmallVector<OpFoldResult>
2436transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
// Applies the tiling-interface pad transform to the payload ops.
// Visible flow for each target:
//  1. require a TilingInterface op that also implements
//     IndexingMapOpInterface, otherwise emit a silenceable error;
//  2. collect one padding value per (padding value, operand type) pair,
//     checking or parsing it against the expected element type;
//  3. configure PadTilingInterfaceOptions with the padding values, the mixed
//     padding sizes and the pad-to-multiple-of flag;
//  4. run the rewrite, record the padded op and its padded operands, and
//     replace the target with the sliced results.
// The two result handles are set from paddedOps and padOps.
// NOTE(review): the loop header, the rewrite call, returns and closing
// braces are missing from this listing.
2441DiagnosedSilenceableFailure
2442transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2443 transform::TransformResults &results,
2444 transform::TransformState &state) {
2445 SmallVector<Operation *> paddedOps, padOps;
2448 auto targetOp = dyn_cast<TilingInterface>(
target);
2450 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2451 diag.attachNote(
target->getLoc()) <<
"target op";
// Ops without IndexingMapOpInterface are rejected as well.
2458 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2459 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2461 diag.attachNote(
target->getLoc()) <<
"target op";
// Build the list of padding values, pairing each attribute with the type of
// the corresponding operand of the target.
2466 SmallVector<Attribute> paddingValues;
2467 for (
auto const &[untypedAttr, elementOrTensorType] :
2468 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2469 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2473 paddingValues.push_back(untypedAttr);
2477 emitOpError(
"expects padding values to be typed attributes or poison");
// String attributes go through a parsing step; the parsed attribute must
// exist and carry the expected element type.
2481 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2485 if (!parsedAttr || parsedAttr.getType() != elementType) {
2487 << elementType <<
", got " << attr;
2488 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2491 paddingValues.push_back(parsedAttr);
// Other typed attributes must match the element type directly.
2495 if (attr.getType() != elementType) {
2497 << elementType <<
", got " << attr;
2498 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2501 paddingValues.push_back(attr);
2505 PadTilingInterfaceOptions
options;
2506 options.setPaddingValues(paddingValues)
2507 .setPaddingSizes(getMixedPaddingSizes())
2508 .setPadToMultipleOf(getPadToMultipleOf());
2510 OpBuilder::InsertionGuard g(rewriter);
2513 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options);
2514 if (
failed(maybePadOps)) {
2515 auto diag = emitSilenceableError() <<
"failed to pad op";
2516 diag.attachNote(
target->getLoc()) <<
"target op";
// Unpack the three parts of the successful result.
2519 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2522 paddedOps.push_back(paddedOp);
2523 padOps.append(paddedOperands.begin(), paddedOperands.end());
2524 rewriter.
replaceOp(targetOp.getOperation(), slicedResults);
2527 results.
set(cast<OpResult>(getPadded()), paddedOps);
2528 results.
set(cast<OpResult>(getPad()), padOps);
// Verifier for PadTilingInterfaceOp. Performs no checks of its own and
// always returns success.
2532LogicalResult transform::PadTilingInterfaceOp::verify() {
return success(); }
// Builds the packing loop nest for hoisting a pad op out of a loop.
// Requires exactly one target and exactly one loop payload op; the target is
// expected to be a tensor::PadOp and the loop an scf::ForOp. The packing-loop
// result handle is set either to the hoisted pad op (when no loop induction
// variables were cloned) or to the outer packed loop.
// NOTE(review): the payload lookups, the call producing result, the handling
// of the wrong-op-kind case and the returns are missing from this listing.
2538DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2539 transform::TransformRewriter &rewriter,
2540 transform::TransformResults &transformResults,
2541 transform::TransformState &state) {
2544 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2546 <<
"requires exactly one target and one loop handle (got "
2547 << llvm::range_size(targetOps) <<
" and "
2548 << llvm::range_size(loopOps) <<
")";
2551 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2552 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2553 if (!padOp || !loopOp)
2556 FailureOr<linalg::detail::PackingResult>
result =
// No cloned induction variables: the hoisted pad op itself is the result.
2562 if (
result->clonedLoopIvs.empty()) {
2563 transformResults.
set(cast<OpResult>(getPackingLoop()),
2564 {
result->hoistedPadOp.getOperation()});
2567 auto outerPackedLoop =
2569 transformResults.
set(cast<OpResult>(getPackingLoop()),
2570 {outerPackedLoop.getOperation()});
// Verifier for HoistPadBuildPackingLoopNestOp. Checks that the transpose
// attribute is a permutation of 0..size-1 by comparing it against the
// identity sequence of the same length with std::is_permutation.
// NOTE(review): the tail of the comparison, the rest of the diagnostic and
// the success return are missing from this listing.
2574LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2575 ArrayRef<int64_t> transpose = getTranspose();
2576 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2577 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2579 return emitOpError() <<
"expects transpose to be a permutation, found "
2585void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2586 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Applies pad hoisting to a single target. Declares the outputs that the
// hoisting call fills (the hoisted pad op and the created transpose ops) and
// falls back to the default silenceable failure for the target.
// NOTE(review): the target parameter, the hoisting call and the success path
// are missing from this listing.
2593DiagnosedSilenceableFailure
2594transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2596 transform::ApplyToEachResultList &results,
2597 transform::TransformState &state) {
2598 tensor::PadOp hoistedPadOp;
2599 SmallVector<TransposeOp> transposeOps;
2600 FailureOr<Value>
result =
2602 hoistedPadOp, transposeOps);
2613 return emitDefaultSilenceableFailure(
target);
// Verifier for HoistPadOp. Checks that the transpose attribute is a
// permutation of 0..size-1 by comparing it against the identity sequence of
// the same length with std::is_permutation.
// NOTE(review): the tail of the comparison, the rest of the diagnostic and
// the success return are missing from this listing.
2616LogicalResult transform::HoistPadOp::verify() {
2617 ArrayRef<int64_t> transpose = getTranspose();
2618 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2619 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2621 return emitOpError() <<
"expects transpose to be a permutation, found "
// Applies promotion to a single target. Translates the op attributes into a
// LinalgPromotionOptions value (operands to promote, full-tile flags, alloca
// use, alignment, memory space) and, when a mapping is present, specialises
// the options by GPU address space (workgroup or private).
// Several paths end in emitDefaultDefiniteFailure for the target.
// NOTE(review): the statements under most of the if-conditions, the
// per-address-space setup and the promotion call are missing from this
// listing.
2631DiagnosedSilenceableFailure
2632transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2634 transform::ApplyToEachResultList &results,
2635 transform::TransformState &state) {
2636 LinalgPromotionOptions promotionOptions;
2637 if (!getOperandsToPromote().empty())
2640 if (getUseFullTilesByDefault())
2642 getUseFullTilesByDefault());
2643 if (getUseOriginalSubviewSize())
2647 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2648 if (!getUseFullTileBuffers().empty())
2650 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
// Alignment and memory space are optional and only forwarded when set.
2651 if (getAlignment().has_value())
2652 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2653 if (getMemorySpace().has_value())
2654 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2656 if (getMapping().has_value()) {
// More than one mapping entry is a definite failure.
2658 auto mapping = *getMapping();
2659 if (mapping.size() > 1)
2660 return emitDefaultDefiniteFailure(
target);
// NOTE(review): cast<> asserts on a different attribute kind, and mapping[0]
// assumes a non-empty array; confirm both are guaranteed by the verifier.
2662 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2664 if (addressSpace.getAddressSpace() ==
2665 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2672 }
else if (addressSpace.getAddressSpace() ==
2673 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2681 return emitDefaultDefiniteFailure(
target);
2686 return emitDefaultDefiniteFailure(
target);
2691 return emitDefaultDefiniteFailure(
target);
// Replaces each payload op with a copy of the op held in the body region.
// A first pass over the payload validates every target (operand count, and
// isolation from above for targets that have regions). A second pass walks
// the payload again, treating targets nested under this transform op
// specially, and collects the replacements into the result handle.
// NOTE(review): the payload lookup, the diagnostics heads, the cloning and
// replacement statements and the returns are missing from this listing.
2700DiagnosedSilenceableFailure
2701transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2702 TransformResults &transformResults,
2703 TransformState &state) {
2707 for (Operation *
target : payload) {
2708 if (
target->getNumOperands() > 0)
// A target with regions must carry the IsIsolatedFromAbove trait.
2710 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2711 target->getNumRegions() > 0)
2713 <<
"expected target that is isolated from above";
// The replacement pattern is the first op of the first block of the body.
2717 Operation *pattern = &getBodyRegion().front().front();
2718 SmallVector<Operation *> replacements;
2719 for (Operation *
target : payload) {
2720 if (getOperation()->isAncestor(
target))
2727 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2731void transform::ReplaceOp::getEffects(
2732 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Verifier for ReplaceOp. The body region must have a single block holding
// exactly one operation; that operation is the replacement. The visible
// diagnostics further require the replacement to have no operands and to be
// isolated from above.
// NOTE(review): the one-block diagnostic, the operand-count condition, the
// second half of the isolation condition and the success return are missing
// from this listing.
2738LogicalResult transform::ReplaceOp::verify() {
2739 if (!getBodyRegion().hasOneBlock())
// Count the ops in the single block; anything but one is an error.
2741 if (std::distance(getBodyRegion().front().begin(),
2742 getBodyRegion().front().end()) != 1)
2743 return emitOpError() <<
"expected one operation in block";
2744 Operation *
replacement = &getBodyRegion().front().front();
2747 <<
"expected replacement without operands";
2748 if (!
replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2751 <<
"expect op that is isolated from above";
// Applies scalarization to a single target by tiling it with scf tiling.
// The tile sizes are computed lazily by a callback that derives per-loop
// sizes from the flat list of operand dimensions and the shapes-to-loops map
// of the target. A failed tiling is a definite failure; on success the tiled
// ops are appended to the results.
// NOTE(review): the target parameter, the body of the per-size selection
// (only its index-attribute-1 alternative is visible), the replacement of the
// original op and the final return are missing from this listing.
2759DiagnosedSilenceableFailure
2760transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2762 transform::ApplyToEachResultList &results,
2763 transform::TransformState &state) {
2764 scf::SCFTilingOptions tilingOptions;
2765 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &
b, Operation *) {
2766 SmallVector<OpFoldResult> tileSizes;
2767 Location loc =
target.getLoc();
2768 SmallVector<OpFoldResult> allShapeSizes =
2769 target.createFlatListOfOperandDims(
b, loc);
2770 AffineMap map =
target.getShapesToLoopsMap();
2773 SmallVector<OpFoldResult> shapeSizes =
2778 for (OpFoldResult shapeSize : shapeSizes) {
2780 :
b.getIndexAttr(1));
2785 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2786 rewriter, cast<TilingInterface>(
target.getOperation()), tilingOptions);
2787 if (
failed(maybeTilingResult))
2788 return emitDefaultDefiniteFailure(
target);
2790 if (
target->getNumResults())
2795 results.
reserve(maybeTilingResult->tiledOps.size());
2796 for (Operation *tiled : maybeTilingResult->tiledOps)
// Lowers each payload op to scf.for loops. Every target must implement
// TilingInterface, otherwise a silenceable error with a note on the payload
// op is emitted. A failed lowering is a definite failure. All generated loops
// are collected into the single result handle.
// NOTE(review): the payload loop header, the returns and the removal of the
// original op are missing from this listing.
2805DiagnosedSilenceableFailure
2806transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2807 transform::TransformResults &results,
2808 transform::TransformState &state) {
2809 SmallVector<Operation *> loops;
2811 auto tilingOp = dyn_cast<TilingInterface>(*
target);
2813 DiagnosedSilenceableFailure
diag =
2814 emitSilenceableError()
2815 <<
"expected the payload to implement TilingInterface";
2816 diag.attachNote(
target->getLoc()) <<
"payload op";
2820 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2821 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2822 if (
failed(generatedLoops))
2823 return emitDefaultDefiniteFailure(
target);
// Record every generated loop as a plain Operation pointer.
2824 for (scf::ForOp &loop : *generatedLoops) {
2825 loops.push_back(loop.getOperation());
2829 results.
set(cast<OpResult>(getResult()), loops);
// Rewrites a single target into destination-passing style. The visible
// dispatch handles tensor::FromElementsOp, tensor::GenerateOp and
// tensor::PadOp through one generic lambda; the failure path returns the
// default silenceable failure for the target.
// NOTE(review): the head of the type switch, the lambda body, the default
// case and the success path are missing from this listing.
2837DiagnosedSilenceableFailure
2838transform::RewriteInDestinationPassingStyleOp::applyToOne(
2839 transform::TransformRewriter &rewriter, Operation *
target,
2840 transform::ApplyToEachResultList &results,
2841 transform::TransformState &state) {
2843 FailureOr<Operation *> maybeResult =
2845 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2846 [&rewriter](
auto op) {
2850 return emitDefaultSilenceableFailure(
target);
// Splits payload ops along one dimension at the given chunk sizes.
// Visible flow:
//  1. in multiway mode exactly one target is required;
//  2. chunk sizes come either from a dynamic handle (payload ops or params)
//     or are filled in by resizing to the payload size; outside multiway
//     mode their count must match the number of targets;
//  3. two local helpers validate a target (structured op, dimension in
//     range) and translate a splitting failure into a diagnostic;
//  4. multiway mode splits the single target repeatedly, continuing on the
//     tail, and records each head plus the final tail;
//  5. otherwise each target is split once at its own chunk size, and either
//     all or none of the targets may produce a second part.
// The single result handle is set to the collected op list.
// NOTE(review): many lines (diagnostic heads, returns, closing braces, parts
// of both loops) are missing from this listing.
2859DiagnosedSilenceableFailure
2860SplitOp::apply(transform::TransformRewriter &rewriter,
2861 TransformResults &results, TransformState &state) {
2863 SmallVector<Operation *> payload =
2866 bool isMultiwaySplit = getMultiway();
2868 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2870 <<
"requires exactly one target when "
2871 "multiway split is enabled (got "
2872 << llvm::range_size(payload) <<
")";
2875 SmallVector<OpFoldResult> chunkSizes;
2877 if (!isMultiwaySplit)
2878 chunkSizes.reserve(payload.size());
2880 if (getDynamicChunkSizes()) {
// A handle-typed operand yields one chunk size per payload op it maps to.
2882 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2883 chunkSizes = llvm::map_to_vector(
2884 state.
getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2887 diag = emitSilenceableError()
2888 <<
"expected dynamic split point handle to point to a "
2889 "single-result index-typed op";
2890 diag.attachNote(op->
getLoc()) <<
"dynamic split point";
// Otherwise the chunk sizes are taken from the params mapped to the operand.
2895 chunkSizes = llvm::map_to_vector(
2896 state.
getParams(getDynamicChunkSizes()),
2897 [](Attribute attr) {
return OpFoldResult(attr); });
2899 if (
diag.isSilenceableFailure())
2904 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2906 <<
"expected the dynamic split point handle to point to as "
2908 << chunkSizes.size() <<
") as the target handle ("
2909 << payload.size() <<
")";
2912 chunkSizes.resize(payload.size(),
// Helper: the target must be a structured op and the requested dimension
// must be one of its loops.
2916 auto checkStructuredOpAndDimensions =
2917 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2919 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2920 diag.attachNote(loc) <<
"target op";
2924 if (getDimension() >= linalgOp.getNumLoops()) {
2925 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2926 <<
" does not exist in target op";
2927 diag.attachNote(loc) <<
"target op";
// Helper: turns a failed split into a diagnostic at the given location.
2933 auto checkFailureInSplitting =
2934 [&](
bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2943 SmallVector<Operation *> opList;
2944 if (isMultiwaySplit) {
2947 TilingInterface head, tail;
2948 Operation *
target = payload.front();
2950 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2953 DiagnosedSilenceableFailure
diag =
2954 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2955 if (
diag.isSilenceableFailure())
// Each iteration splits the current target; the tail of the previous split
// becomes the next target.
2958 for (
auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2961 target = tail.getOperation();
2966 linalgOp = cast<LinalgOp>(
target);
2967 Location loc =
target->getLoc();
2971 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2972 getDimension(), chunkSize);
// The split counts as failed only when neither part was produced.
2975 DiagnosedSilenceableFailure
diag =
2976 checkFailureInSplitting(!head && !tail, loc);
2977 if (
diag.isDefiniteFailure())
2980 opList.push_back(head.getOperation());
2985 opList.push_back(tail.getOperation());
// Non-multiway mode: one split per (target, chunk size) pair.
2989 SmallVector<Operation *> first, second;
2990 Operation *noSecondPart =
nullptr;
2991 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2992 Operation *
target = std::get<0>(pair);
2993 Location loc =
target->getLoc();
2994 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2995 DiagnosedSilenceableFailure
diag =
2996 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2998 if (
diag.isSilenceableFailure())
// New slots are appended to both lists and filled by the split result.
3002 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
3003 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
3004 getDimension(), std::get<1>(pair));
3007 DiagnosedSilenceableFailure diagSplit =
3008 checkFailureInSplitting(!first.back() && !second.back(), loc);
3013 if (!second.back()) {
// Mixed outcome: some targets produced a second part and some did not.
3019 if (second.size() != first.size() && !second.empty()) {
3020 auto diag = emitSilenceableError()
3021 <<
"splitting does not produce the second part for a subset "
3024 <<
"expected splitting to produce the second part of all "
3025 "or none of the targets";
3027 <<
"first target with no second part";
// All first parts come before all second parts in the result list.
3031 opList.append(first);
3032 if (!second.empty())
3033 opList.append(second);
3035 results.
set(cast<OpResult>(getSplitList()), opList);
// Side-effect declaration for SplitOp. The visible part of the body has a
// branch that is taken only when the dynamic chunk sizes operand is present.
// NOTE(review): the effect registration calls are missing from this listing.
3039void SplitOp::getEffects(
3040 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3042 if (getDynamicChunkSizes())
// Custom parser for SplitOp. Parses the target operand and then either a
// dynamic chunk-size operand or, when no such operand is present, a static
// integer chunk size. A type for the chunk sizes is parsed only in the
// dynamic case. The static chunk size is stored under the attribute name from
// getStaticChunkSizesAttrName, and the op result takes the target type.
// NOTE(review): the individual parser calls and error returns are missing
// from this listing.
3048ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &
result) {
3049 OpAsmParser::UnresolvedOperand
target, dynamicChunkSizes;
3050 IntegerAttr staticChunkSizes;
3054 OptionalParseResult dynamicPointParseResult =
// No optional operand was found: fall back to a static integer.
3056 if (!dynamicPointParseResult.
has_value()) {
3057 int64_t staticChunkSizesValue;
3071 if (dynamicPointParseResult.
has_value()) {
3072 Type chunkSizesType;
3085 SplitOp::getStaticChunkSizesAttrName(
result.name).getValue(),
3087 result.addTypes(targetType);
// Custom printer for SplitOp. Prints the target, the keyword after, and then
// the chunk size: the static value unless it equals ShapedType::kDynamic, in
// which case the dynamic operand is printed instead. The static chunk size
// attribute name is passed in a braced list to a call whose head is not
// visible. The target type follows, plus the dynamic operand type in the
// dynamic case.
// NOTE(review): the else keyword and the attribute-dictionary call head are
// missing from this listing.
3091void SplitOp::print(OpAsmPrinter &printer) {
3092 printer <<
" " << getTarget() <<
" after ";
3093 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
3094 if (staticChunkSize != ShapedType::kDynamic)
3095 printer << staticChunkSize;
3097 printer << getDynamicChunkSizes();
3100 {getStaticChunkSizesAttrName()});
3101 printer <<
" : " << getTarget().getType();
3102 if (staticChunkSize == ShapedType::kDynamic)
3103 printer <<
", " << getDynamicChunkSizes().getType();
// Verifier for SplitOp. Exactly one of the static and the dynamic split
// point must be provided. The exclusive-or below is true, and an error is
// emitted, when a static value is set while a dynamic operand is also
// present, or when the static value is kDynamic and no dynamic operand
// exists.
// NOTE(review): the success return is missing from this listing.
3106LogicalResult SplitOp::verify() {
3107 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3108 (getDynamicChunkSizes() ==
nullptr)) {
3109 return emitOpError() <<
"expects either a dynamic or a static split "
3110 "point to be provided";
// Builder for SplitReductionOp. Adds the split factor and the insert split
// dimension attributes, adds the inner-parallel and use-scaling-algorithm
// attributes only when the corresponding flag is true, adds a use-alloc
// attribute, and declares four results, all of transform::AnyOpType.
// NOTE(review): the attribute values, the condition around the use-alloc
// attribute, the operand registration and the definition of ctx are missing
// from this listing.
3119void transform::SplitReductionOp::build(
3120 OpBuilder &builder, OperationState &
result, Value
target,
3121 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
3122 bool useScalingAlgorithm,
bool useAlloc) {
3125 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(
result.name),
3128 SplitReductionOp::getInsertSplitDimensionAttrName(
result.name),
3130 if (innerParallel) {
3131 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(
result.name),
3134 if (useScalingAlgorithm) {
3136 SplitReductionOp::getUseScalingAlgorithmAttrName(
result.name),
3140 result.addAttribute(SplitReductionOp::getUseAllocAttrName(
result.name),
3143 auto resultType = transform::AnyOpType::get(ctx);
3144 result.addTypes({resultType, resultType, resultType, resultType});
// Applies split-reduction to a single LinalgOp target. The options are built
// from the split factor, the insert split dimension and the inner-parallel
// flag; the use-scaling-algorithm flag selects between two implementations.
// A failed split is a definite failure. On success the visible code appends
// the init-or-alloc op, the split linalg op and the result-combining linalg
// op to the results.
// NOTE(review): the callback wrapper around the options, both split calls,
// one results line and the final return are missing from this listing.
3147DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3148 transform::TransformRewriter &rewriter, LinalgOp
target,
3149 transform::ApplyToEachResultList &results,
3150 transform::TransformState &state) {
3152 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3153 unsigned(getInsertSplitDimension()),
3154 bool(getInnerParallel())};
3157 FailureOr<SplitReductionResult> splitResult =
3158 (getUseScalingAlgorithm())
3162 return emitDefaultDefiniteFailure(
target);
3164 results.
push_back(splitResult->initOrAlloc);
3166 results.
push_back(splitResult->splitLinalgOp);
3167 results.
push_back(splitResult->resultCombiningLinalgOp);
// Builder for TileReductionUsingForOp from a target handle and static tile
// sizes. Result types are based on transform::AnyOpType.
// NOTE(review): the definition of ctx, the construction of
// staticTileSizesAttr and the delegated build call are missing from this
// listing.
3175void transform::TileReductionUsingForOp::build(
3176 OpBuilder &builder, OperationState &
result, Value
target,
3177 ArrayRef<int64_t> staticTileSizes) {
3184 auto opTy = transform::AnyOpType::get(ctx);
3190 staticTileSizesAttr);
// Tiles the reduction of a single target using scf.for loops.
// The target must implement PartialReductionOpInterface. When no reduction
// dimensions are given, every loop whose iterator type is reduction is used.
// Tiling goes through scf::tileUsingSCF with the ForOp loop type; the
// initial values, tiled ops and merge ops of the result are then iterated.
// NOTE(review): the diagnostics heads, the tile-size and strategy setup, the
// loop bodies that publish results and the returns are missing from this
// listing.
3193DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3194 transform::TransformRewriter &rewriter, Operation *
target,
3195 transform::ApplyToEachResultList &results,
3196 transform::TransformState &state) {
3199 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3200 if (!partialReductionOp) {
3203 "Operation should implement PartialReductionOpInterface");
3206 SmallVector<unsigned> reductionDims =
// Default: collect the index of every reduction iterator.
3208 if (reductionDims.empty()) {
3209 for (
auto [idx, iteratorType] :
3210 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3211 if (iteratorType == utils::IteratorType::reduction)
3212 reductionDims.push_back(idx);
3216 scf::SCFTilingOptions
options;
3217 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3218 options.setReductionTilingStrategy(
3221 options.setReductionDims(reductionDims);
3222 FailureOr<scf::SCFTilingResult>
result =
3223 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3227 "failed to tile using partial reduction");
3230 for (Value initValue :
result->initialValues)
3232 for (
auto *parallelTiledOp :
result->tiledOps)
3234 for (
auto *mergeOp :
result->mergeOps)
// Builder for TileReductionUsingForallOp from a target handle, static thread
// counts and static tile sizes. Result types are based on
// transform::AnyOpType.
// NOTE(review): the remaining parameters, the definition of ctx, the
// construction of the two attributes and the delegated build call are
// missing from this listing.
3244void transform::TileReductionUsingForallOp::build(
3245 OpBuilder &builder, OperationState &
result, Value
target,
3246 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3254 auto opTy = transform::AnyOpType::get(ctx);
3261 staticNumThreadsAttr,
3262 staticTileSizesAttr,
// Tiles the reduction of a single target using scf.forall.
// The target must implement PartialReductionOpInterface. The thread count is
// forwarded to the tiling options only when num_threads is non-empty; tile
// sizes and an optional mapping are forwarded as well. When no reduction
// dimensions are given, every loop whose iterator type is reduction is used.
// A tiling failure is reported as a silenceable error; on success the
// initial values, tiled ops and merge ops of the result are iterated.
// NOTE(review): the diagnostics heads, the strategy argument, the loop
// bodies that publish results and the returns are missing from this listing.
3266DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3267 transform::TransformRewriter &rewriter, Operation *
target,
3268 transform::ApplyToEachResultList &results,
3269 transform::TransformState &state) {
3272 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3273 if (!partialReductionOp) {
3276 "Operation should implement PartialReductionOpInterface");
3278 SmallVector<OpFoldResult> numThreads =
3280 SmallVector<OpFoldResult> tileSizes =
3283 scf::SCFTilingOptions
options;
3284 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3285 options.setReductionTilingStrategy(
3287 if (!getNumThreads().empty()) {
3288 options.setNumThreads(numThreads);
3290 options.setTileSizes(tileSizes);
// The mapping is optional and only forwarded when present.
3292 if (
auto mapping = getMapping()) {
3293 options.setMapping(mapping.value().getValue());
3295 SmallVector<unsigned> reductionDims =
// Default: collect the index of every reduction iterator.
3297 if (reductionDims.empty()) {
3298 for (
auto [idx, iteratorType] :
3299 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3300 if (iteratorType == utils::IteratorType::reduction)
3301 reductionDims.push_back(idx);
3304 options.setReductionDims(reductionDims);
3305 FailureOr<scf::SCFTilingResult>
result =
3306 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3309 auto diag = emitSilenceableError() <<
"could not tile reduction";
3314 for (Value initValue :
result->initialValues)
3316 for (
auto *parallelTiledOp :
result->tiledOps)
3318 for (
auto *mergeOp :
result->mergeOps)
// Computes continuous tile sizes and the matching chunk sizes for exactly one
// target op.
// Two modes, selected by the type of the chunk-sizes result:
//  - param mode: requires a statically shaped payload op, computes a static
//    specification and publishes tile sizes and chunk sizes as i64
//    attributes, each chunk size being tile size times trip count;
//  - handle mode: emits the tile-size computation as IR, multiplies each tile
//    size by its trip count through an affine apply helper, and publishes the
//    defining ops of the resulting values.
// NOTE(review): the payload lookup, the checks on linalgOp and tileableOp,
// both specification calls, the lambda bodies and the returns are missing
// from this listing.
3328DiagnosedSilenceableFailure
3329transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3330 TransformResults &transformResults,
3331 TransformState &state) {
3333 SmallVector<Operation *> targetOps =
3336 if (!llvm::hasSingleElement(targetOps)) {
3338 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3342 Operation *
target = *targetOps.begin();
3343 auto linalgOp = dyn_cast<LinalgOp>(
target);
3344 auto tileableOp = dyn_cast<TilingInterface>(
target);
3349 OpBuilder builder(linalgOp.getContext());
// Param mode: only possible when all shapes are static.
3351 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3352 if (linalgOp.hasDynamicShape()) {
3353 auto diag = emitSilenceableError()
3354 <<
"cannot compute parametric tile sizes for dynamically "
3355 "shaped payload op";
3356 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3360 FailureOr<StaticContinuousTileSizeSpecification> spec =
3364 return emitSilenceableError()
3365 <<
"failed to compute multi-size tiling sizes";
3368 SmallVector<int64_t> chunkSizes;
// zip_equal is used, so tile sizes and trip counts are expected to have the
// same length.
3370 for (
auto &&[tileSize, tripCount] :
3371 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3372 chunkSizes.push_back(tileSize * tripCount);
3374 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3375 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3380 getI64AttrsFromI64(spec->tileSizes));
3381 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3382 getI64AttrsFromI64(chunkSizes));
// Handle mode: the sizes are computed by generated IR.
3389 OpFoldResult targetSize = builder.
getIndexAttr(getTargetSize());
3390 unsigned dimension = getDimension();
3393 builder, tileableOp, dimension, targetSize,
true);
3395 return emitSilenceableError() <<
"could not generate tile size computation";
3400 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3405 SmallVector<Value> chunkSizes;
3407 for (
auto &&[tileSize, tripCount] :
3408 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3409 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3410 chunkSizes.push_back(splitPoint);
3413 auto getDefiningOps = [&](ArrayRef<Value> values) {
3414 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3420 getDefiningOps(spec->tileSizes));
3421 transformResults.
set(cast<OpResult>(getChunkSizes()),
3422 getDefiningOps(chunkSizes));
// Verifier for ContinuousTileSizesOp. The visible failure path reports that all
// result types must be identical; the guarding condition is on source lines
// this listing omits.
3427LogicalResult transform::ContinuousTileSizesOp::verify() {
3430 return emitOpError() <<
"expects all results type to be the same";
3436void transform::ContinuousTileSizesOp::getEffects(
3437 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3454 Type &tileSizesType,
3455 Type &chunkSizesType) {
3456 FunctionType funcType;
3458 if (failed(parser.
parseType<FunctionType>(funcType)))
3461 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3462 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3463 "argument and one result";
3465 targetType = funcType.getInput(0);
3466 tileSizesType = chunkSizesType = funcType.getResult(0);
// Builder overload taking static tile sizes. It forwards to another `build`
// overload, passing `loopTypes`, `interchange` and `scalableSizes` through.
// NOTE(review): the first parameter line and the middle call arguments are
// missing from this listing.
3475void transform::TileUsingForOp::build(
3477 Value
target, ArrayRef<int64_t> staticTileSizes,
3478 ArrayRef<int64_t> interchange,
3479 std::optional<ArrayRef<bool>> scalableSizes) {
3480 return build(builder,
result, loopTypes,
3484 interchange, scalableSizes);
// Builder overload taking static tile sizes and no explicit loop types.
// Only the tail of its single call is visible; `interchange` and
// `scalableSizes` are passed through unchanged.
3487void transform::TileUsingForOp::build(
3488 OpBuilder &builder, OperationState &
result, Value
target,
3489 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3490 std::optional<ArrayRef<bool>> scalableSizes) {
3493 interchange, scalableSizes);
// Builder overload taking mixed (static or dynamic) tile sizes and no explicit
// loop types. It supplies a single transform::AnyOpType as the loop type list
// and delegates to the overload that accepts `loopTypes`.
3496void transform::TileUsingForOp::build(
3497 OpBuilder &builder, OperationState &
result, Value
target,
3498 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3499 std::optional<ArrayRef<bool>> scalableSizes) {
3502 SmallVector<Type> loopTypes(1, builder.
getType<transform::AnyOpType>());
3503 build(builder,
result, loopTypes,
target, mixedTileSizes, interchange,
// Builder overload taking explicit loop types and mixed tile sizes. It derives
// the loop result types and the per-size scalable flags.
// NOTE(review): the lines that fill `staticTileSizes` / `dynamicTileSizes` and
// the head of the final call are missing from this listing.
3507void transform::TileUsingForOp::build(
3509 Value
target, ArrayRef<OpFoldResult> mixedTileSizes,
3510 ArrayRef<int64_t> interchange,
3511 std::optional<ArrayRef<bool>> scalableSizes) {
3512 SmallVector<int64_t> staticTileSizes;
3513 SmallVector<Value> dynamicTileSizes;
// A static size of 0 does not count toward the expected number of loops.
3519 unsigned numExpectedLoops =
3520 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3521 SmallVector<Type> resultTypes;
3522 resultTypes.reserve(numExpectedLoops);
3523 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3524 "expected one loop type or as many as loops");
// A single loop type is repeated for every expected loop; a full list is
// appended as given (the `else` between the two statements is presumably on
// the omitted line 3527).
3525 if (loopTypes.size() == 1)
3526 resultTypes.append(numExpectedLoops, loopTypes[0]);
3528 llvm::append_range(resultTypes, loopTypes);
// Scalable flags default to false for every size unless the caller supplied
// them.
3529 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(),
false);
3530 if (scalableSizes.has_value())
3531 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3536 staticTileSizesAttr,
3538 expandedScalableSizes);
// Verifier for TileUsingForOp. Two checks are visible:
//  * a size-count mismatch error involving getScalableSizes() (its condition
//    is on a source line this listing omits);
//  * the number of `loops` results must equal the number of static sizes
//    that are not 0.
3541LogicalResult transform::TileUsingForOp::verify() {
3543 return emitOpError(
"expected same number of sizes (")
3545 << getScalableSizes().size() <<
")";
3546 ArrayRef<int64_t> staticSizes = getStaticSizes();
// Static sizes equal to 0 are excluded from the expected loop count.
3547 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3548 if (getLoops().size() != numExpectedLoops)
3549 return emitOpError(
"expected number of loops to tile (")
3550 << numExpectedLoops <<
") to match number of `loops` results ("
3551 << getLoops().size() <<
")";
// Applies TileUsingForOp: tiles every target payload op with tileUsingSCF and
// publishes the tiled ops and the generated loops as results.
//
// Dynamic tile sizes arrive either as parameters (collected in `paramSizes`)
// or as handles to size-producing ops (collected in `dynamicSizeProducers`);
// for each dynamic size exactly one of the two lists is non-empty, and each
// must hold one entry per target op.
//
// NOTE(review): this listing drops several source lines (loop headers, some
// `return`s and closing braces), so it is kept in its listing form here; only
// the diagnostic operand described below is changed.
3555DiagnosedSilenceableFailure
3556transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3557 TransformResults &transformResults,
3558 TransformState &state) {
3559 ArrayRef<int64_t> tileSizes = getStaticSizes();
3561 SmallVector<Operation *> targets =
3563 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3564 SmallVector<SmallVector<int64_t>> paramSizes;
// Parameter-typed size: record an empty producer list and read the integer
// parameter values instead.
3568 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
3569 dynamicSizeProducers.push_back({});
3570 ArrayRef<Attribute> params = state.
getParams(transformValue);
3571 paramSizes.push_back(llvm::map_to_vector(params, [](Attribute attr) {
3572 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3575 if (paramSizes.back().size() != targets.size()) {
3576 DiagnosedSilenceableFailure
diag =
3577 emitSilenceableError()
3578 <<
"expected as many parameter values ("
// Fix: report the number of parameter values that was actually compared.
// The previous operand, dynamicSizeProducers.back().size(), is always 0 on
// this path because an empty list was just pushed above.
3579 << paramSizes.back().size() <<
") as target ops ("
3580 << targets.size() <<
")";
3581 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
// Handle-typed size: record an empty parameter list and collect the payload
// ops that produce the size.
3587 paramSizes.push_back({});
3588 dynamicSizeProducers.push_back(
3591 if (dynamicSizeProducers.back().size() != targets.size()) {
3592 DiagnosedSilenceableFailure
diag =
3593 emitSilenceableError()
3594 <<
"expected as many dynamic size-producing operations ("
3595 << dynamicSizeProducers.back().size() <<
") as target ops ("
3596 << targets.size() <<
")";
3597 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3601 for (Operation *op : dynamicSizeProducers.back()) {
3607 DiagnosedSilenceableFailure
diag =
3608 emitSilenceableError() <<
"expected sizes to be produced by ops "
3609 "with a single index-type result";
3610 diag.attachNote(op->
getLoc()) <<
"size producer op";
3611 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3616 SmallVector<Operation *> tiled;
// One bucket of loop ops per `loops` result of this transform op.
3617 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3618 loops.resize(getLoops().size());
3619 auto scalableSizes = getScalableSizes();
3620 for (
auto [i, op] : llvm::enumerate(targets)) {
3621 auto tilingInterface = dyn_cast<TilingInterface>(op);
3622 if (!tilingInterface) {
3623 DiagnosedSilenceableFailure
diag =
3624 emitSilenceableError()
3625 <<
"only ops implementing TilingInterface are supported";
3626 diag.attachNote(op->
getLoc()) <<
"target op";
// More tile sizes than loop iterators on the target is an error.
3629 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3630 DiagnosedSilenceableFailure
diag =
3631 emitSilenceableError()
3632 <<
"too many tiles provided, expected at most "
3633 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3634 << tileSizes.size();
3635 diag.attachNote(op->
getLoc()) <<
"target op";
3639 scf::SCFTilingOptions tilingOptions;
3640 if (tileSizes.empty()) {
3641 tilingOptions.setTileSizeComputationFunction(
3642 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
// `index = i` captures the position of the current target so the callback
// picks this target's entry from the per-target size lists.
3646 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &
b,
3648 SmallVector<OpFoldResult> sizes;
3649 sizes.reserve(tileSizes.size());
3650 unsigned dynamicIdx = 0;
3652 for (
auto [ofrIdx, ofr] : llvm::enumerate(
getMixedSizes())) {
3653 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
// Scalable static size: the visible ops multiply a value by the result of
// vector.vscale (the definitions of `val` / `vscale` are partly omitted).
3654 if (scalableSizes[ofrIdx]) {
3656 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3658 vector::VectorScaleOp::create(
b, getLoc(),
b.getIndexType());
3660 arith::MulIOp::create(
b, getLoc(), val, vscale).getResult());
3662 sizes.push_back(attr);
3666 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3667 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3669 assert((dynamicSizes.empty() ^ params.empty()) &&
3670 "expected either dynamic sizes or parameters");
3671 if (!params.empty()) {
3672 sizes.push_back(
b.getIndexAttr(params[index]));
3674 sizes.push_back(dynamicSizes[index]->getResult(0));
3681 tilingOptions.setInterchange(getInterchange());
3682 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3683 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3684 if (
failed(maybeTilingResult))
3687 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3689 tiled.append(maybeTilingResult->tiledOps);
// The k-th generated loop of every target goes into the k-th bucket.
3690 for (
const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3691 loops[en2.index()].push_back(en2.value());
3694 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3695 for (
const auto &en : llvm::enumerate(loops))
3696 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
// Builds the list of mixed tile sizes from the static sizes. Every entry equal
// to ShapedType::kDynamic is replaced by the next element of `dynamic`,
// consumed in order.
// NOTE(review): the definition of `dynamic` and the branch for static entries
// are on source lines this listing omits.
3701SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3703 ArrayRef<int64_t> tileSizes = getStaticSizes();
3704 SmallVector<OpFoldResult> results;
3705 results.reserve(tileSizes.size());
3706 unsigned dynamicPos = 0;
3708 for (int64_t size : tileSizes) {
3709 if (size == ShapedType::kDynamic) {
3710 results.push_back(dynamic[dynamicPos++]);
3718void transform::TileUsingForOp::getEffects(
3719 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Builder overload selected by the unnamed transform::TileSizesSpec tag
// parameter; takes static tile sizes and delegates to another `build` overload.
// NOTE(review): most parameters and call arguments are missing from this
// listing.
3730void transform::TileUsingForallOp::build(OpBuilder &builder,
3732 ArrayRef<int64_t> staticTileSizes,
3733 transform::TileSizesSpec,
3735 return build(builder,
result,
// Builder overload selected by the unnamed transform::TileSizesSpec tag
// parameter; takes mixed tile sizes. Both result types passed in the visible
// TypeRange are transform::AnyOpType.
// NOTE(review): the lines that fill the static/dynamic size vectors and the
// head of the final call are missing from this listing.
3743void transform::TileUsingForallOp::build(OpBuilder &builder,
3745 ArrayRef<OpFoldResult> mixedTileSizes,
3746 transform::TileSizesSpec,
3748 SmallVector<int64_t> staticTileSizes;
3749 SmallVector<Value> dynamicTileSizes;
3755 auto operationType = transform::AnyOpType::get(ctx);
3758 TypeRange{operationType, operationType},
3765 staticTileSizesAttr,
// Builder overload selected by the unnamed transform::NumThreadsSpec tag
// parameter; takes static thread counts. The visible call tail passes a
// default-constructed NumThreadsSpec() and `mapping`; the head of that call is
// missing from this listing.
3769void transform::TileUsingForallOp::build(OpBuilder &builder,
3771 ArrayRef<int64_t> staticNumThreads,
3772 transform::NumThreadsSpec,
3776 NumThreadsSpec(), mapping);
// Builder overload selected by the unnamed transform::NumThreadsSpec tag
// parameter; takes mixed thread counts. Both result types passed in the
// visible TypeRange are transform::AnyOpType.
// NOTE(review): the lines that fill the static/dynamic vectors and the head of
// the final call are missing from this listing.
3779void transform::TileUsingForallOp::build(OpBuilder &builder,
3781 ArrayRef<OpFoldResult> mixedNumThreads,
3782 transform::NumThreadsSpec,
3784 SmallVector<int64_t> staticNumThreads;
3785 SmallVector<Value> dynamicNumThreads;
3792 auto operationType = transform::AnyOpType::get(ctx);
3795 TypeRange{operationType, operationType},
3801 staticNumThreadsAttr,
3808static SmallVector<OpFoldResult>
3814 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3816 for (
auto [lb,
ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3818 rewriter, loc, normalizedUbExpr, {lb,
ub, step});
3819 normalizedUbs.push_back(normalizedUb);
3821 return normalizedUbs;
3837 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3840 denormalizedIvs.push_back(
3843 return denormalizedIvs;
3854 scf::ForallOp loop) {
3871 auto normalizedForallOp = scf::ForallOp::create(
3872 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3873 loop.getOutputs(), loop.getMapping(),
3876 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3878 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3883 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3884 normalizedForallOp.getRegionIterArgs().end());
3885 Block *origLoopBlock = loop.getBody();
3886 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3888 rewriter.
replaceOp(loop, normalizedForallOp);
3889 return normalizedForallOp;
3897 scf::SCFTilingResult &tilingResult) {
3899 auto tileableOp = dyn_cast<TilingInterface>(
target);
3902 transformOp.emitSilenceableError()
3903 <<
"only TilingInterface ops are supported";
3904 diag.attachNote(
target->getLoc()) <<
"target op";
3908 scf::SCFTilingOptions
options;
3909 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3910 if (!mixedNumThreads.empty()) {
3911 options.setNumThreads(mixedNumThreads);
3913 options.setTileSizes(mixedTileSizes);
3916 options.setMapping(mapping.value().getValue());
3918 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3919 scf::tileUsingSCF(rewriter, tileableOp,
options);
3921 if (failed(maybeTilingResult))
3922 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3924 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3926 tilingResult = *maybeTilingResult;
3930 if (mixedNumThreads.empty() && !tilingResult.loops.empty()) {
3931 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3934 scf::ForallOp normalizedForallOp =
3936 tilingResult.loops.front() = normalizedForallOp;
3946 auto transformOp = cast<TransformOpInterface>(getOperation());
3955 getPackedNumThreads()
3957 state, transformOp, mixedNumThreads, getPackedNumThreads())
3959 state, transformOp, mixedNumThreads, getMixedNumThreads());
3963 status = getPackedTileSizes()
3965 state, transformOp, mixedTileSizes, getPackedTileSizes())
3967 state, transformOp, mixedTileSizes, getMixedTileSizes());
3972 scf::SCFTilingResult tilingResult;
3974 rewriter, state, transformOp,
target, mixedNumThreads, mixedTileSizes,
3975 getMapping(), tilingResult);
3976 if (!
diag.succeeded())
3978 if (!tilingResult.loops.empty())
3979 tileOps.push_back(tilingResult.loops.front());
3980 tiledOps.append(tilingResult.tiledOps);
3983 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3984 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3989void transform::TileUsingForallOp::getEffects(
3990 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4000SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
4005SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
// Verifier for TileUsingForallOp. Each of the two specifications (number of
// threads, tile sizes) may be given in at most one form, mixed or packed, and
// at least one of the two specifications must be present.
4010LogicalResult TileUsingForallOp::verify() {
// Counts how many of {mixed num_threads, packed_num_threads} are provided.
4011 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
4012 static_cast<int>(getPackedNumThreads() != Value());
4013 if (numThreadsSpec > 1)
4015 "num_threads and packed_num_threads are mutually exclusive");
// Counts how many of {mixed tile_sizes, packed_tile_sizes} are provided.
4016 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
4017 static_cast<int>(getPackedTileSizes() != Value());
4018 if (tileSizesSpec > 1)
4020 "tile_sizes and packed_tile_sizes are mutually exclusive");
4021 if (numThreadsSpec == 0 && tileSizesSpec == 0)
4022 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
4023 "must be specified");
// Builder for VectorizeChildrenAndApplyPatternsOp. Each boolean flag, when
// true, enters a branch that refers to the matching attribute name of the op.
// NOTE(review): the statements that use those attribute names are on source
// lines this listing omits.
4031void transform::VectorizeChildrenAndApplyPatternsOp::build(
4032 OpBuilder &builder, OperationState &
result, Value
target,
4033 bool foldTypeExtensionsIntoContract,
bool vectorizePadding,
4034 bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
4036 if (foldTypeExtensionsIntoContract) {
4038 VectorizeChildrenAndApplyPatternsOp::
4039 getFoldTypeExtensionsIntoContractAttrName(
result.name),
4042 if (vectorizePadding) {
4044 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4048 if (vectorizeExtract) {
4050 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4054 if (flatten1DDepthwiseConv) {
4056 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
// Rewrite pattern registered for any op type (MatchAnyOpTypeTag) with
// benefit 1. On success the matched op is replaced by the replacements of the
// vectorization result. The two flags stored at construction are forwarded to
// the vectorization call.
// NOTE(review): the callee name and the match precondition are on source lines
// this listing omits.
4066struct VectorizationPattern :
public RewritePattern {
4067 explicit VectorizationPattern(MLIRContext *context,
4068 bool vectorizeExtract =
false,
4069 bool flattenConv =
false)
4070 : RewritePattern(MatchAnyOpTypeTag(), 1, context),
4071 vectorizeNDExtract(vectorizeExtract),
4072 flatten1DDepthwiseConv(flattenConv) {}
4073 LogicalResult matchAndRewrite(Operation *op,
4074 PatternRewriter &rewriter)
const override {
4077 "Unsupported Op, cannot vectorize");
4078 FailureOr<VectorizationResult> vectorResults =
4080 {}, vectorizeNDExtract,
4081 flatten1DDepthwiseConv);
4082 if (
failed(vectorResults))
4084 rewriter.
replaceOp(op, vectorResults->replacements);
// Flag set from the `vectorizeExtract` constructor argument.
4091 bool vectorizeNDExtract =
false;
// Flag set from the `flattenConv` constructor argument.
4095 bool flatten1DDepthwiseConv =
false;
// Applies the op to one target: rejects targets that are not
// isolated-from-above, assembles a RewritePatternSet (vectorization plus
// optional clean-up patterns gated by the op's attributes) and runs it with a
// TrackingListener attached. A failed run yields a definite failure.
// NOTE(review): the pattern-population calls under the `if (!getDisable...)`
// and later guards, and the head of the driver call, are on source lines this
// listing omits.
4099DiagnosedSilenceableFailure
4100transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4101 transform::TransformRewriter &rewriter, Operation *
target,
4102 transform::ApplyToEachResultList &results,
4103 transform::TransformState &state) {
4104 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4105 auto diag = this->
emitOpError(
"requires isolated-from-above targets");
4106 diag.attachNote(
target->getLoc()) <<
"non-isolated target";
4111 RewritePatternSet patterns(ctx);
// The vectorization pattern receives the op's nd-extract and
// flatten-depthwise-conv settings.
4112 patterns.
add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4113 getFlatten_1dDepthwiseConv());
4115 if (!getDisableTransferPermutationMapLoweringPatterns())
4118 if (!getDisableMultiReductionToContractPatterns())
4123 patterns.
add<linalg::LinalgCopyVTRForwardingPattern,
4124 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4126 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4127 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4130 patterns.
add<CopyVectorizationPattern>(ctx);
4132 if (getFoldTypeExtensionsIntoContract())
4135 if (getVectorizePadding()) {
// The listener is constructed from the transform state and this op and is
// installed through GreedyRewriteConfig.
4143 TrackingListener listener(state, *
this);
4146 GreedyRewriteConfig().setListener(&listener))))
4147 return emitDefaultDefiniteFailure(
target);
// Applies VectorizeOp to every target. The mixed vector sizes are resolved
// into `vectorSizes`, then each target is vectorized; unset optional flags
// default to false via value_or(false).
// NOTE(review): the definition of `targets`, the early return for an empty
// target list, the emitters of the two diagnostics and the callee of the
// vectorization call are on source lines this listing omits.
4157DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4158 transform::TransformRewriter &rewriter,
4159 mlir::transform::TransformResults &transformResults,
4160 mlir::transform::TransformState &state) {
4162 if (std::empty(targets))
4164 auto transformOp = cast<TransformOpInterface>(getOperation());
4165 SmallVector<int64_t> vectorSizes;
4167 state, transformOp, getMixedVectorSizes(), vectorSizes);
4172 for (Operation *
target : targets) {
4175 <<
"Unsupported Op, cannot vectorize";
4177 FailureOr<VectorizationResult> vectorResults =
4179 getVectorizeNdExtract().value_or(
false),
4181 getAssumeDynamicDimsMatchVecSizes().value_or(
false),
4182 getCreateNamedContraction().value_or(
false));
4183 if (
failed(vectorResults)) {
4185 <<
"Attempted to vectorize, but failed";
4193void transform::VectorizeOp::getEffects(
4194 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4200SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
// Verifier for VectorizeOp: the number of static vector sizes must equal the
// number of scalable-size flags; the error message reports both counts.
// NOTE(review): the trailing success return is not part of this listing.
4205LogicalResult transform::VectorizeOp::verify() {
4206 if (getStaticVectorSizes().size() != getScalableSizes().size())
4207 return emitOpError(
"expected same number of vector sizes (")
4208 << getStaticVectorSizes().size() <<
") and scalable sizes ("
4209 << getScalableSizes().size() <<
")";
4217DiagnosedSilenceableFailure
4218transform::HoistRedundantVectorTransfersOp::applyToOne(
4219 transform::TransformRewriter &rewriter, func::FuncOp
target,
4220 transform::ApplyToEachResultList &results,
4221 transform::TransformState &state) {
4234DiagnosedSilenceableFailure
4235transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4236 transform::TransformRewriter &rewriter, mlir::Operation *
target,
4237 transform::ApplyToEachResultList &results,
4238 transform::TransformState &state) {
// Applies the op to one linalg target by dispatching on its concrete type:
// Conv2DNhwcHwcfOp, Conv2DNhwcFhwcOp, DepthwiseConv2DNhwcHwcOp and
// Conv2DNchwFchwOp have dedicated cases; everything else goes to the default
// case. A failed rewrite yields the default silenceable failure. On success
// both members of the resulting pair are pushed to `results`, `first` then
// `second`.
// NOTE(review): the bodies of the case lambdas are on source lines this
// listing omits.
4249DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4250 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4251 transform::ApplyToEachResultList &results,
4252 transform::TransformState &state) {
4254 auto maybeTransformed =
4257 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4260 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4263 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4266 .Case([&](linalg::Conv2DNchwFchwOp op) {
4269 .Default([&](Operation *op) {
4272 if (
failed(maybeTransformed))
4273 return emitDefaultSilenceableFailure(
target);
4275 results.
push_back(maybeTransformed->first);
4277 results.
push_back(maybeTransformed->second);
// Applies the op to one linalg target. Targets with at most one loop get
// separate handling (its body is not in this listing). Otherwise a
// reassociation filled with 0, 1, 2, ... is built and, on success, the
// collapsed op is pushed to `results`.
// NOTE(review): the elementwise precondition, the declaration of
// `reassociation` and the flattening callee are on source lines this listing
// omits.
4285DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4286 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4287 transform::ApplyToEachResultList &results,
4288 transform::TransformState &state) {
4292 <<
"only elementwise flattening is supported";
4295 if (
target.getNumLoops() <= 1) {
4302 std::iota(reassociation.begin(), reassociation.end(), 0);
4303 auto maybeFlattened =
4305 if (
failed(maybeFlattened))
4307 <<
"attempted to flatten, but failed";
4308 results.
push_back(maybeFlattened->collapsedOp);
// Applies the op to one linalg target by dispatching on its concrete type:
// Conv2DNhwcFhwcOp and Conv2DNhwcFhwcQOp have dedicated cases; everything else
// goes to the default case. A failed rewrite yields the default silenceable
// failure.
// NOTE(review): the case bodies and the success path are on source lines this
// listing omits.
4317DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4318 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4319 transform::ApplyToEachResultList &results,
4320 transform::TransformState &state) {
4322 auto maybeTransformed =
4324 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4327 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4330 .Default([&](Operation *op) {
4333 if (
failed(maybeTransformed))
4334 return emitDefaultSilenceableFailure(
target);
// Applies the op to one linalg target. `transposeLHS` is true when the op's
// input-to-transpose setting equals TransposeMatmulInput::lhs. MatmulOp and
// BatchMatmulOp have dedicated cases; any other op type yields failure().
// NOTE(review): the case bodies and the statements after the failure check are
// on source lines this listing omits.
4344DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4345 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4346 transform::ApplyToEachResultList &results,
4347 transform::TransformState &state) {
4349 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4350 auto maybeTransformed =
4352 .Case([&](linalg::MatmulOp op) {
4355 .Case([&](linalg::BatchMatmulOp op) {
4358 .Default(failure());
4359 if (
failed(maybeTransformed))
4369template <
typename OpTy>
4370static DiagnosedSilenceableFailure
4374 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4375 tensor::ParallelInsertSliceOp>() &&
4378 if (
auto copySource =
4379 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4387 if (isa<mlir::ParallelCombiningOpInterface>(
target.getOperation()))
4390 Value extracted = tensor::ExtractSliceOp::create(
4393 Value copied = linalg::CopyOp::create(rewriter,
target.getLoc(),
4394 target.getSource(), extracted)
// Applies the op to one target. tensor::InsertSliceOp and
// tensor::ParallelInsertSliceOp targets are handed to `doit`; any other op
// produces a silenceable error with a note pointing at the target.
4406DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4407 transform::TransformRewriter &rewriter, Operation *targetOp,
4408 transform::ApplyToEachResultList &results,
4409 transform::TransformState &state) {
4412 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4413 return doit(rewriter,
target, results, state);
4414 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4415 return doit(rewriter,
target, results, state);
4417 DiagnosedSilenceableFailure
diag =
4418 emitSilenceableError()
4419 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4420 diag.attachNote(targetOp->
getLoc()) <<
"target op";
// Applies the op to one target. Only linalg::CopyOp and tensor::PadOp targets
// with a statically shaped single result are accepted. A gpu::CopyMappingInfo
// is computed from the thread count, bit alignment and result shape; an
// invalid mapping is reported as a silenceable error. On success the first
// generated loop is pushed to `results`, followed by a loop over the tiled ops.
// NOTE(review): several constructor/call arguments and the `return`s after the
// diagnostics are on source lines this listing omits.
4428DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4429 transform::TransformRewriter &rewriter, Operation *
target,
4430 transform::ApplyToEachResultList &results,
4431 transform::TransformState &state) {
4433 if (!isa<linalg::CopyOp, tensor::PadOp>(
target)) {
4434 DiagnosedSilenceableFailure
diag =
4435 emitSilenceableError()
4436 <<
"only linalg.copy and tensor.pad target ops are supported";
4437 diag.attachNote(
target->getLoc()) <<
"target op";
4440 assert(
target->getNumResults() == 1 &&
"expected single result");
4441 auto resultShapedType = cast<ShapedType>(
target->getResult(0).getType());
// NOTE(review): the guard below only tests hasStaticShape(); the "rank <= 3"
// part of the message is not checked in the visible lines.
4442 if (!resultShapedType.hasStaticShape()) {
4443 DiagnosedSilenceableFailure
diag =
4444 emitSilenceableError()
4445 <<
"only statically sized ops of rank <= 3 are supported";
4446 diag.attachNote(
target->getLoc()) <<
"target op";
// The requested alignment is replaced by the element bit width when it is not
// a multiple of that width.
4451 int64_t desiredBitAlignment = getDesiredBitAlignment();
4452 int64_t eltBitwidth =
4453 resultShapedType.getElementType().getIntOrFloatBitWidth();
4454 if (desiredBitAlignment % eltBitwidth != 0) {
4455 desiredBitAlignment = eltBitwidth;
4458 gpu::CopyMappingInfo mapping(
4460 getTotalNumThreads(),
4461 desiredBitAlignment,
4462 resultShapedType.getShape(),
4465 resultShapedType.getElementType().getIntOrFloatBitWidth());
4466 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4467 DiagnosedSilenceableFailure
diag =
4468 emitSilenceableError()
4469 <<
"too few threads to map copy op to threads on the most minor "
4470 "dimension, given alignment and vector size constraints, try "
4471 "smaller tile size of mapping to more threads";
4472 diag.attachNote(
target->getLoc()) <<
"target op";
// Tiling call: an empty OpFoldResult list and the thread mapping wrapped in an
// array attribute are visible among its arguments.
4478 scf::SCFTilingResult tilingResult;
4485 ArrayRef<OpFoldResult>{},
4486 b.getArrayAttr(mapping.threadMapping),
4488 if (!
diag.succeeded())
4491 results.
push_back(tilingResult.loops.front());
4492 for (
auto *op : tilingResult.tiledOps)
// Applies the op to one linalg target. `maybeTransformed` starts as failure().
// Conv2DNhwcFhwcOp has a dedicated case and every other op type makes the
// switch return false. Two silenceable errors are visible: one for unsupported
// ops, one for a failed rewrite.
// NOTE(review): the case body, the condition guarding the "not supported"
// error and the success path are on source lines this listing omits.
4501DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4502 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4503 transform::ApplyToEachResultList &results,
4504 transform::TransformState &state) {
4506 FailureOr<Operation *> maybeTransformed = failure();
4508 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4513 .Default([&](Operation *op) {
return false; });
4516 return emitSilenceableError()
4517 <<
"this operation is not supported to convert to Winograd Conv2D";
4520 if (
failed(maybeTransformed)) {
4521 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
// Applies the op to one target. `maybeTransformed` starts as failure().
// WinogradFilterTransformOp, WinogradInputTransformOp and
// WinogradOutputTransformOp have dedicated cases. Two silenceable errors are
// visible, each with a note pointing at the target: one for unsupported ops,
// one for a failed decomposition.
// NOTE(review): the case bodies, the condition guarding the "not supported"
// error and the success path are on source lines this listing omits.
4528DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4529 transform::TransformRewriter &rewriter, Operation *
target,
4530 transform::ApplyToEachResultList &results,
4531 transform::TransformState &state) {
4533 FailureOr<Operation *> maybeTransformed = failure();
4536 .Case([&](linalg::WinogradFilterTransformOp op) {
4540 .Case([&](linalg::WinogradInputTransformOp op) {
4544 .Case([&](linalg::WinogradOutputTransformOp op) {
4551 DiagnosedSilenceableFailure
diag =
4552 emitSilenceableError()
4553 <<
"this operation is not supported to decompose into other operations";
4554 diag.attachNote(
target->getLoc()) <<
"target op";
4558 if (
failed(maybeTransformed)) {
4559 DiagnosedSilenceableFailure
diag =
4560 emitSilenceableError() <<
"decompose Winograd operations failed";
4561 diag.attachNote(
target->getLoc()) <<
"target op";
4569#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4571#define GET_OP_CLASSES
4572#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Base type for affine expression.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation, bool PoisonPaddingOk=false)
Patterns to bubble up or down data layout ops across other operations.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
void populateExtractSliceSinkingPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to sink extract slice across other operations.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice + copy.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Canonicalization patterns relevant to apply after tiling patterns.
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors and memref.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace a linalg.batch_matmul op with a transposed variant (transposing the LHS when transposeLHS is true).
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of linalg.pack/unpack ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.