42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/ScopeExit.h"
44#include "llvm/ADT/SmallPtrSet.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/DebugLog.h"
48#include "llvm/Support/LogicalResult.h"
55#define DEBUG_TYPE "linalg-transforms"
62template <
typename PatternTy,
typename... Args>
65 using OpTy =
typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
72 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
80 return cast<LinalgOp>(
result->getOperation());
90 if (
auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 <<
"requires exactly one parameter associated";
103 result.push_back(params[0]);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 <<
"handle must be mapped to exactly one payload op";
113 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
120 transformOp.emitSilenceableError()
121 <<
"payload op must have exactly 1 index result";
141 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
143 for (
auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 <<
"expected the parameter to be associated with an integer "
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 <<
"payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 <<
"has " << op->getNumResults() <<
" results";
162 result.push_back(op->getResult(0));
176 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
180 if (isa<TransformParamTypeInterface>(
181 cast<Value>(paramOrHandle).
getType())) {
183 if (params.size() != 1)
184 return transformOp.emitSilenceableError() <<
"expected a single param";
186 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
190 Value handle = cast<Value>(paramOrHandle);
191 if (!isa<TransformHandleTypeInterface>(handle.getType()))
192 return transformOp.emitSilenceableError() <<
"unexpected value handle";
194 if (!llvm::hasSingleElement(payload))
195 return transformOp.emitSilenceableError()
196 <<
"requires param or handle that is mapped to 1 payload op";
198 Operation *paramOrHandlePayloadOp = *payload.begin();
201 return transformOp.emitSilenceableError()
202 <<
"requires param or handle to be result of op with 1 index "
208 return transformOp.emitSilenceableError()
209 <<
"requires param or handle to be the result of a constant like "
212 reified.push_back(attr.getInt());
221void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
226void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
231void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
236void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
242void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
245 options.rankReductionStrategy =
250void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
255void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
260void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
265void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
270void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
275void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
284void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
288 Operation *producer = opOperand->get().getDefiningOp();
289 Operation *consumer = opOperand->getOwner();
304 SmallVector<Operation *> getNewOps()
const {
305 return SmallVector<Operation *>(newOps.begin(), newOps.end());
309 void notifyOperationInserted(Operation *op,
310 OpBuilder::InsertPoint previous)
override {
311 ForwardingListener::notifyOperationInserted(op, previous);
313 if (previous.
isSet())
317 assert(
inserted.second &&
"expected newly created op");
320 void notifyOperationErased(Operation *op)
override {
321 ForwardingListener::notifyOperationErased(op);
322 op->
walk([&](Operation *op) { newOps.erase(op); });
334 llvm::scope_exit resetListener(
335 [&]() { rewriter.
setListener(previousListener); });
336 NewOpsListener newOpsListener(previousListener);
340 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
341 options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
342 MaterializeInDestination;
343 }
else if (getMemcpyOp() ==
"memref.copy") {
346 }
else if (getMemcpyOp() ==
"linalg.copy") {
350 llvm_unreachable(
"invalid memcpy op");
352 if (getAllocOp() ==
"memref.alloc") {
355 }
else if (getAllocOp() ==
"memref.alloca") {
359 llvm_unreachable(
"invalid alloc op");
361 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
362 options.emitDealloc = getEmitDealloc();
366 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
373 <<
"failed to bufferize operation";
374 diag.attachNote(op->
getLoc()) <<
"target payload op";
377 allocatedBuffers.push_back(buffer);
381 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
382 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
386void transform::BufferizeToAllocationOp::getEffects(
388 if (getBufferizeDestinationOnly()) {
399LogicalResult transform::BufferizeToAllocationOp::verify() {
400 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
401 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
403 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
416 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.
getOwner());
423 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
431 if (!isa<TensorType, FloatType, IntegerType>(value.
getType()))
433 return llvm::any_of(value.
getUses(),
443 auto type = dyn_cast<RankedTensorType>(
tensor.getType());
445 return emitSilenceableError() <<
"non-tensor type: " <<
tensor;
459 for (
auto [pos, dim] : llvm::enumerate(type.getShape())) {
460 if (!ShapedType::isDynamic(dim))
465 tensor::DimOp::create(rewriter,
tensor.getLoc(),
tensor, cst);
466 preservedOps.insert(dimOp);
467 dynamicDims.push_back(dimOp);
469 auto allocation = bufferization::AllocTensorOp::create(
470 rewriter,
tensor.getLoc(), type, dynamicDims);
472 if (getMemorySpaceAttr())
473 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
474 Value allocated = allocation;
478 if (needsMaterialization) {
479 auto copy = bufferization::MaterializeInDestinationOp::create(
481 preservedOps.insert(
copy);
482 promoted.push_back(
copy.getResult());
484 promoted.push_back(allocated);
488 results.
setValues(cast<OpResult>(getPromoted()), promoted);
492void transform::PromoteTensorOp::getEffects(
508 FailureOr<linalg::LinalgOp> res =
510 if (succeeded(res)) {
514 return emitDefaultSilenceableFailure(
target);
528 auto decomposableOp = dyn_cast<AggregatedOpInterface>(
target);
529 if (!decomposableOp) {
531 "payload is not a decomposable op"));
532 return emitDefaultSilenceableFailure(
target);
535 FailureOr<SmallVector<Value>> maybeNewResults =
536 decomposableOp.decomposeOperation(rewriter);
537 if (
failed(maybeNewResults))
538 return emitDefaultSilenceableFailure(
target);
540 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
541 for (
Value val : *maybeNewResults) {
542 Operation *definition = val.getDefiningOp();
553void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
560transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
564 options.allowReturnAllocsFromLoops =
true;
570 <<
"failed to analyze op";
572 rewriter,
target, state)))
574 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
587 bool applyCleanup,
bool useForall) {
589 builder,
result, loopTypes,
595 applyCleanup, useForall);
601 bool applyCleanup,
bool useForall) {
609 applyCleanup, useForall);
616 bool applyCleanup,
bool useForall) {
620 build(builder,
result, loopTypes,
target, mixedTileSizes,
621 mixedTileInterchange, applyCleanup, useForall);
628 bool applyCleanup,
bool useForall) {
635 staticTileInterchange);
640 auto staticTileInterchangeAttr =
642 unsigned numExpectedLoops =
643 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
645 resultTypes.reserve(numExpectedLoops);
646 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
647 "expected one loop type or as many as loops");
648 if (loopTypes.size() == 1)
649 resultTypes.append(numExpectedLoops, loopTypes[0]);
651 llvm::append_range(resultTypes, loopTypes);
656 dynamicTileInterchange,
658 staticTileInterchangeAttr,
665template <
typename Range>
669 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
675 auto tilingInterfaceOp = dyn_cast<TilingInterface>(
target);
676 if (!tilingInterfaceOp)
677 return transformOp->
emitError(
"only TilingInterface ops are supported");
680 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
681 applyFn(tilingInterfaceOp);
682 if (failed(tiledResults))
687 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
688 for (
Operation *toReplace : opsToReplace) {
689 for (
OpResult res : toReplace->getResults())
690 if (
auto replacement = tiledResults->replacements.lookup(res))
692 if (toReplace->use_empty()) {
698 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
699 assert(tiledResults->loops.size() == numLoops &&
700 "Mismatched number of loops, tile and fuse transform should have "
702 for (
unsigned int i = 0; i < numLoops; ++i)
703 loopOps[i].
push_back(tiledResults->loops[i]);
706 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
707 for (
unsigned int i = 0; i < numLoops; ++i)
708 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
717 auto transformOp = cast<TransformOpInterface>(getOperation());
721 state, transformOp, getMixedTileSizes(), tileSizes);
726 state, transformOp, getMixedTileInterchange(), tileInterchange);
730 scf::SCFTilingOptions tilingOptions;
731 tilingOptions.interchangeVector = tileInterchange;
732 bool useForall = getUseForall();
733 tilingOptions.setLoopType(useForall
734 ? scf::SCFTilingOptions::LoopType::ForallOp
735 : scf::SCFTilingOptions::LoopType::ForOp);
738 tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
739 scf::SCFTileAndFuseOptions tileAndFuseOptions;
740 tileAndFuseOptions.tilingOptions = tilingOptions;
742 if (getApplyCleanup()) {
745 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
748 tileAndFuseOptions.cleanupPatterns = std::move(patterns);
752 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
754 rewriter, getOperation(), state.
getPayloadOps(getTarget()), numLoops,
756 [&](TilingInterface tilingInterfaceOp)
757 -> FailureOr<scf::SCFTileAndFuseResult> {
758 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
765LogicalResult transform::FuseOp::verify() {
766 auto iterspace_rank = getStaticTileSizes().size();
768 if (permutation.size() > iterspace_rank)
770 <<
"interchange length exceeds iteration space dimensions ("
771 << iterspace_rank <<
"), found " << getTileInterchange();
773 for (
int64_t v : permutation) {
774 if (!ShapedType::isDynamic(v)) {
775 if (v < 0 || v >=
static_cast<int64_t>(iterspace_rank))
776 return emitOpError() <<
"expects interchange values to be in range [0, "
777 << iterspace_rank <<
"), found: " << v;
779 return emitOpError() <<
"found duplicate interchange value: " << v;
785 size_t numExpectedLoops =
786 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
787 if (numExpectedLoops != getNumResults() - 1)
788 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
798 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
802void transform::FuseOp::getEffects(
815void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
818 Value containingOp) {
819 result.addOperands({producerOp, containingOp});
820 auto resultType = transform::AnyOpType::get(builder.
getContext());
821 result.addTypes({resultType, resultType});
837 (domInfo.
dominates(containingOp, user))) {
838 dominatedUsers.insert(user);
841 if (dominatedUsers.empty())
845 auto forallOp = cast<scf::ForallOp>(containingOp);
851 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
856 newOuts.push_back(outputs[resultNumber]);
859 auto newforallOp = scf::ForallOp::create(
860 rewriter, loc, forallOp.getMixedLowerBound(),
861 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
862 forallOp.getMapping());
864 newforallOp.getRegion().takeBody(forallOp.getRegion());
869 newforallOp.getBody()->addArgument(newOuts.back().getType(),
870 newOuts.back().getLoc());
871 auto bbArgs = newforallOp.getBody()->getArguments();
874 Operation *op = use.getOwner();
875 return newforallOp->isProperAncestor(op);
879 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
881 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; });
882 Operation *firstYieldOp = yieldingOps.front();
885 Value dst = newforallOp.getRegionIterArgs().back();
887 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->
getLoc(), src,
888 dst, offsets, sizes, strides);
890 for (
auto result : llvm::enumerate(forallOp.getResults())) {
892 newforallOp->getResult(
result.index()));
895 newforallOp->getResults().back(),
897 Operation *user = use.getOwner();
898 return dominatedUsers.contains(user);
912 destWorklist.push_back(dst);
914 while (!destWorklist.empty()) {
915 Value currentDst = destWorklist.pop_back_val();
919 if (src == currentDst)
924 auto bbArg = dyn_cast<BlockArgument>(currentDst);
928 Block *parentBlock = bbArg.getOwner();
929 assert(parentBlock &&
"unlinked block argument");
932 assert(parentOp &&
"expected block argument with parent operation");
935 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
939 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
941 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
942 Value loopBlockArgument =
944 destWorklist.push_back(loopBlockArgument);
957static std::tuple<SmallVector<Operation *>,
Operation *>
960 LDBG() <<
"Try to fuse a direct extract use";
961 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
962 if (!tileableProducer) {
964 <<
"producer is not a TileableInterface: " << *producerOp;
971 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
972 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
973 return sliceOp && containingOp->isProperAncestor(sliceOp);
977 if (it == tileableProducer->getUsers().end()) {
978 diag.attachNote(tileableProducer->getLoc())
979 <<
"could not find fusion opportunity for: " << *tileableProducer;
982 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
995 if (LoopLikeOpInterface containerLoop =
996 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
1002 auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(
clone);
1006 for (
OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
1007 Value producerOperand =
1008 clone->getOperand(initOperandPtr.getOperandNumber());
1010 containerLoop.getRegionIterArgs()) {
1011 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1012 Value consumerOperand =
1016 initOperandPtr.set(containerIterArg);
1022 tileableProducer = dyn_cast<TilingInterface>(
clone);
1027 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1028 LDBG() <<
"resultNumber: " << resultNumber;
1033 FailureOr<TilingResult> tileAndFuseResult =
1034 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1037 if (failed(tileAndFuseResult)) {
1038 diag.attachNote(tileableProducer->getLoc())
1039 <<
"failed to tile producer op: " << *tileableProducer;
1044 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
1045 LDBG() <<
"tiledProducer: " << *tiledOp;
1050 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1051 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1052 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1053 if (failed(maybeRankReduced)) {
1055 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
1056 << tileAndFuseResult->tiledValues[0]
1057 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
1060 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1064 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
1065 resultNumber, offsets, sizes);
1068 if (isa<LoopLikeOpInterface>(containingOp))
1069 rewriter.
eraseOp(tileableProducer);
1071 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1084 LDBG() <<
"Try to fuse an extract use through block argument";
1086 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1087 if (!tileableProducer) {
1089 <<
"producer is not a TileableInterface: " << *producerOp;
1094 scf::ForallOp forallOp;
1095 auto itProducerUses =
1096 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
1097 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1101 if (!forallOp || forallOp != containingOp) {
1102 diag.attachNote(tileableProducer->getLoc())
1103 <<
"could not find a use by the containing op: " << *tileableProducer;
1118 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1119 return sliceOp && containingOp->isProperAncestor(sliceOp);
1123 if (itBBArgUsers == bbArg.
getUsers().end()) {
1125 <<
"could not find fusion opportunity for bbArg: " << bbArg;
1128 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1136 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
1137 LDBG() <<
"resultNumber: " << resultNumber;
1142 rewriter, tileableProducer->getLoc(), tileableProducer,
1143 destinationTensors))) {
1144 diag.attachNote(tileableProducer->getLoc())
1145 <<
"failed to get destination tensors for: " << *tileableProducer;
1150 bvm.
map(destinationTensors[resultNumber], bbArg);
1151 auto tileableProducerClone =
1152 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
1153 llvm::scope_exit scopeGuard(
1154 [&]() { rewriter.
eraseOp(tileableProducerClone); });
1157 FailureOr<TilingResult> tileAndFuseResult =
1158 tileableProducerClone.generateResultTileValue(
1159 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1160 sliceOpToTile.getMixedSizes());
1161 if (failed(tileAndFuseResult)) {
1162 diag.attachNote(tileableProducer->getLoc())
1163 <<
"failed to tile producer op: " << *tileableProducer;
1168 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1169 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1170 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1171 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
1172 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1177 destinationTensors.front());
1180 return tileAndFuseResult->tiledOps;
1186 LDBG() <<
"Try to fuse an use by cloning";
1193 uses.push_back(&use);
1198 if (containingOp == use.getOwner()) {
1200 <<
"producer op use by containing op cannot be fused by cloning";
1208 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1217 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1218 "Parallel insert slice is not a valid clone destination");
1219 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1220 LDBG() <<
"resultNumber: " << resultNumber;
1224 fusedOp = rewriter.
clone(*producerOp);
1226 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1231bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1242 auto containingOps = state.
getPayloadOps(getContainingOp());
1243 if (!llvm::hasSingleElement(containingOps)) {
1245 <<
"requires exactly one containing_op handle (got "
1246 << llvm::range_size(containingOps) <<
")";
1248 Operation *containingOp = *containingOps.begin();
1251 if (std::empty(producerOps)) {
1253 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1260 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1261 for (
const auto &it :
enumerate(remainingProducers)) {
1264 int64_t numUsesInContainingOp =
1266 return containingOp->isAncestor(op);
1271 if (numUsesInContainingOp > 0) {
1272 if (numUsesInContainingOp == 1)
1273 remainingProducers.erase(remainingProducers.begin() + it.index());
1280 while (!remainingProducers.empty()) {
1281 auto nextProducer = getNextProducer();
1282 if (
failed(nextProducer)) {
1284 <<
"could not find next producer to fuse into container";
1285 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1293 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1300 auto [tiledOps, newContainingOp] =
1302 if (!tiledOps.empty()) {
1303 LDBG() <<
"\nFused a direct extract use\n" << *containingOp;
1304 fusedOps.append(tiledOps);
1305 if (newContainingOp) {
1313 LogicalResult replacementStatus =
1316 (
void)replacementStatus;
1317 assert(succeeded(replacementStatus) &&
1318 "unable to update transform state mapping");
1319 rewriter.
eraseOp(containingOp);
1320 containingOp = newContainingOp;
1327 rewriter,
diag, producerOp, containingOp);
1328 if (!tiledContainingOpOperand.empty()) {
1329 LDBG() <<
"\nFused an extract use through block argument\n"
1331 fusedOps.append(tiledContainingOpOperand);
1338 LDBG() <<
"\nFused an use by cloning\n" << *containingOp;
1339 fusedOps.push_back(cloned);
1345 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1346 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1350void transform::FuseIntoContainingOp::getEffects(
1368 if (isa<GenericOp>(
target)) {
1374 if (succeeded(generic)) {
1375 results.
push_back(generic->getOperation());
1378 return emitDefaultSilenceableFailure(
target);
1391 if (!isa<GenericOp>(
target)) {
1398 FailureOr<LinalgOp> named =
1400 if (succeeded(named)) {
1401 results.
push_back(named->getOperation());
1404 return emitDefaultSilenceableFailure(
target);
1418 if (interchangeVector.empty()) {
1423 unsigned numLoops = cast<LinalgOp>(
target.getOperation()).getNumLoops();
1424 if (interchangeVector.size() != numLoops) {
1425 return emitSilenceableError()
1426 << getIteratorInterchangeAttrName() <<
" has length ("
1427 << interchangeVector.size()
1428 <<
") different from the number of loops in the target operation ("
1439LogicalResult transform::InterchangeOp::verify() {
1441 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1442 if (!std::is_permutation(sequence.begin(), sequence.end(),
1443 permutation.begin(), permutation.end())) {
1445 <<
"expects iterator_interchange to be a permutation, found "
1446 << getIteratorInterchange();
1461 if (!isa<linalg::CopyOp>(targetOp)) {
1463 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1464 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1468 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1469 if (!copyOp.hasPureBufferSemantics()) {
1471 emitSilenceableError()
1472 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1473 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1479 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1480 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1481 Value input = inputs.front();
1482 Value output = outputs.front();
1487 if (!isa<ShapedType>(input.
getType())) {
1489 emitSilenceableError()
1490 <<
"cannot transform a linalg.copy which input has no shape";
1491 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1496 assert(isa<ShapedType>(output.
getType()));
1498 if (cast<ShapedType>(input.
getType()).getElementType() !=
1499 cast<ShapedType>(output.
getType()).getElementType()) {
1501 emitSilenceableError()
1502 <<
"cannot transform a linalg.copy with different source and "
1503 "destination element types ";
1504 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1525 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1526 FailureOr<LowerPackResult> res =
1530 <<
"cannot lower to pad + expand + transpose";
1533 transformResults.
push_back(res->expandShapeOp);
1534 transformResults.
push_back(res->transposeOp);
1547 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1548 FailureOr<LowerUnPackOpResult> res =
1552 emitSilenceableError()
1553 <<
"cannot lower to transpose + collapse + extract";
1554 diag.attachNote(
target->getLoc()) <<
"target payload op";
1557 transformResults.
push_back(res->emptyOp);
1558 transformResults.
push_back(res->transposeOp);
1559 transformResults.
push_back(res->collapseShapeOp);
1560 transformResults.
push_back(res->extractSliceOp);
1561 transformResults.
push_back(res->copyOp);
1572 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1581 result.addAttribute(MatchOp::getOpsAttrName(
result.name),
1583 result.addTypes(resultTypes);
1591 if (getOps().has_value())
1592 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1595 if (!llvm::hasSingleElement(payloadOps)) {
1600 bool incorrectNumOperandTypes =
false;
1607 if (getInterface().has_value()) {
1608 auto iface = getInterface().value();
1609 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1612 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1613 !isa<TilingInterface>(op))
1615 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1616 !isa<LoopLikeOpInterface>(op))
1621 if (getOpAttrs().has_value()) {
1622 DictionaryAttr opAttrs = getOpAttrs().value();
1624 if (attr.getName() == getInterfaceAttrName() ||
1625 attr.getName() == getOpsAttrName())
1627 if (!op->
hasAttr(attr.getName()))
1629 if (op->
getAttr(attr.getName()) != attr.getValue())
1634 if (getFilterResultType().has_value()) {
1635 Type t = getFilterResultType().value();
1640 if (getFilterOperandTypes().has_value()) {
1641 mlir::ArrayAttr types = getFilterOperandTypes().value();
1644 if (types.size() == 1) {
1647 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1648 Type t = cast<::mlir::Type>(typeattr.getValue());
1650 [&](
Type operandType) { return operandType == t; }))
1655 if (types.size() != operandTypes.size()) {
1656 incorrectNumOperandTypes =
true;
1660 for (
auto [attr, operandType] :
1661 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1662 auto typeattr = cast<mlir::TypeAttr>(attr);
1663 Type type = cast<::mlir::Type>(typeattr.getValue());
1665 if (type != operandType)
1676 (*payloadOps.begin())->walk(matchFun);
1677 if (incorrectNumOperandTypes)
1679 "type, then it must contain as much types as "
1680 "the number of operands in the target ops");
1681 results.
set(cast<OpResult>(getResult()), res);
1696 Type &targetType,
Type &lowSizeType,
1698 Type &splitPointType) {
1699 FunctionType funcType;
1701 if (failed(parser.
parseType<FunctionType>(funcType)))
1704 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1705 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1706 "argument and one result";
1708 targetType = funcType.getInput(0);
1709 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1717 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1718 if (
target.hasDynamicShape()) {
1719 auto diag = emitSilenceableError()
1720 <<
"cannot compute parametric tile sizes for dynamically "
1721 "shaped payload op";
1722 diag.attachNote(
target->getLoc()) <<
"payload op";
1727 target, getDimension(), getTargetSize(), getDivisor());
1729 return emitSilenceableError()
1730 <<
"failed to compute multi-size tiling sizes";
1734 results.
assign(llvm::map_range(
1736 spec->lowTileSize * spec->lowTripCount}),
1737 [&builder,
this](
int64_t value) {
1749 builder,
target, getDimension(), targetSize, divisor);
1751 return emitSilenceableError() <<
"could not generate tile size computation";
1758 {spec->lowTileSize, spec->lowTripCount});
1759 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1760 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1761 assert(lowTileSize && highTileSize && splitPoint &&
1762 "tile sizes are not produced by operations");
1770void transform::MultiTileSizesOp::getEffects(
1774 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1780LogicalResult transform::MultiTileSizesOp::verify() {
1783 return emitOpError() <<
"expects all results type to be the same";
1802 Type linalgOpHType = transform::OperationType::get(
1803 builder.
getContext(), GenericOp::getOperationName());
1822 if (std::empty(targetOps)) {
1823 transformResults.
set(cast<OpResult>(getPackedOp()),
1828 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1829 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1830 return emitSilenceableError()
1831 <<
"requires target to map to exactly 1 LinalgOp (got "
1832 << llvm::range_size(targetOps) <<
")";
1835 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1836 return emitSilenceableError()
1837 <<
"requires number of packed sizes match the number of loops ("
1838 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1845 state, *
this, packedSizes, getMixedPackedSizes());
1848 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1852 transformResults.
set(cast<OpResult>(getPackedOp()),
1853 {maybeResult->packedLinalgOp.getOperation()});
1857void transform::PackOp::getEffects(
1869LogicalResult transform::PackGreedilyOp::verify() {
1871 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1872 <<
" is not a valid permutation";
1875 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1876 for (
auto [s, nmo] :
1877 llvm::zip_equal(getMixedMatmulPackedSizes(),
1878 getMatmulPaddedSizesNextMultipleOf())) {
1881 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1882 return emitOpError() <<
"at most one of the packed_size and the "
1883 "padded_sizes_next_multiple_of can be nonzero "
1884 "for the matmul strategy";
1897 auto linalgOp = dyn_cast<LinalgOp>(op);
1908 getMixedMatmulPackedSizes(),
1910 getMatmulPaddedSizesNextMultipleOf(),
1911 getMatmulInnerDimsOrder());
1912 if (succeeded(packResult)) {
1913 results.push_back(packResult->packedLinalgOp);
1916 results.push_back(linalgOp);
1918 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1924 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1928void transform::PackGreedilyOp::getEffects(
1940LogicalResult transform::PackTransposeOp::verify() {
1943 <<
" is not a valid permutation";
1947 <<
" is not a valid permutation";
1949 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1950 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1951 <<
" or " << getOuterPermAttrName()
1952 <<
" must be specified";
/// Selector for which permutation of a pack/unpack op is being validated or
/// queried: the permutation of the outer (untiled) dims or of the inner
/// (tiled) dims. Used by isValidPackingPermutation to pick the rank the
/// permutation must match, and by PackTransposeOp to pair each case with the
/// corresponding outer_perm / inner_perm attribute.
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1968template <
typename RelayoutOpTy>
1969static bool isValidPackingPermutation(
1971 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1973 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1974 "applies to only pack or unpack operations");
1975 if (!op || permutation.empty())
1977 size_t innerRank = op.getInnerDimsPos().size();
1978 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1982 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1983 return permutation.size() == op.getSourceRank() &&
1986 return permutation.size() == op.getDestRank() &&
// PackTransposeOp::apply (FRAGMENT — extraction dropped lines; the original
// function signature and several statements are missing; numeric prefixes are
// fused source line numbers, not code).
// Visible logic: resolve the pack/unpack handle and the packed-op handle,
// require exactly one payload op for each, check the pack/unpack op and the
// LinalgOp target are adjacent (single use / direct producer), validate
// outer_perm / inner_perm, then transpose via packTranspose and publish the
// transposed ops as results.
1994 auto packOrUnpackOps = state.
getPayloadOps(getTargetPackOrUnPackOp());
// Empty handle: succeed with empty results for all three outputs.
1997 if (std::empty(packOrUnpackOps)) {
1998 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1999 transformResults.
set(cast<OpResult>(getPackOp()), {});
2000 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
// Both handles must map to exactly one payload op each.
2006 if (!llvm::hasSingleElement(packOrUnpackOps) ||
2007 !llvm::hasSingleElement(linalgOps)) {
2008 return emitSilenceableError()
2009 <<
"requires target to map to exactly 1 "
2010 "packing op and 1 packed op ("
2011 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
2012 << llvm::range_size(linalgOps) <<
")";
// The packing handle must be a linalg.pack or linalg.unpack.
2016 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2017 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2018 if ((!packOp && !unPackOp)) {
2019 return emitSilenceableError() <<
"requires target to map to a "
2020 "linalg.pack or linalg.unpack";
2022 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2023 if (!linalgOpTarget)
2024 return emitSilenceableError() <<
"requires a LinalgOp target";
// Adjacency check: pack result must be single-used by the LinalgOp target,
// or the unpack source must be produced by it.
2028 if (packOp && packOp.getResult().hasOneUse())
2029 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2031 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2032 if (linalgOp != linalgOpTarget) {
2034 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
2035 : StringLiteral{
"not produced by the LinalgOp target"};
2036 return emitSilenceableError() << errorMsg;
// unpack case: recover the matching pack through the DPS init operand.
2042 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
2043 OpOperand *packUse = linalgOp.getDpsInitOperand(
2044 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2046 if (!packOp || !packOp.getResult().hasOneUse())
2047 return emitSilenceableError() <<
"could not find matching pack op";
// Validate both permutations against both ops (null ops skip the check —
// presumably inside isValidPackingPermutation; missing lines, confirm).
2051 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2053 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2054 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2055 ? StringLiteral{
"invalid outer_perm"}
2056 : StringLiteral{
"invalid inner_perm"};
2057 if (!isValidPackingPermutation(packOp, perm, permType) ||
2058 !isValidPackingPermutation(unPackOp, perm, permType)) {
2060 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2061 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
// Perform the transposition; preconditions above make failure a bug.
2067 assert(packOp && linalgOp &&
"unexpected null op");
2071 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2073 assert(succeeded(res) &&
"unexpected packTranspose failure");
// Publish transposed ops; unpack result is empty when there was no unpack.
2076 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2077 transformResults.
set(cast<OpResult>(getPackedOp()),
2078 {res->transposedLinalgOp});
2080 transformResults.
set(cast<OpResult>(getUnPackOp()),
2081 {res->transposedUnPackOp});
2083 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
// Two transform::PadOp::build overloads (FRAGMENTS — the leading parameter
// lists and the build(...) call heads were dropped by extraction).
// First overload: static-only pad_to_multiple_of; an empty array is encoded
// as a null attribute. Second overload: mixed OpFoldResult multiples are
// dispeded into dynamic values + static attribute before delegating.
2098 StringRef copyBackOp,
2099 bool usePrescribedTensorShapes) {
2100 auto resultType = transform::AnyOpType::get(
b.getContext());
2106 b.getI64ArrayAttr(paddingDimensions),
2109 (padToMultipleOf.empty()
2111 :
b.getDenseI64ArrayAttr(padToMultipleOf)),
2112 b.getI64ArrayAttr(nofoldFlags),
2113 b.getArrayAttr(transposePaddings),
2114 b.getStringAttr(copyBackOp),
// use_prescribed_tensor_shapes is a UnitAttr: present iff true.
2116 usePrescribedTensorShapes ?
b.getUnitAttr() :
nullptr);
2124 StringRef copyBackOp,
2125 bool usePrescribedTensorShapes) {
2126 auto resultType = transform::AnyOpType::get(
b.getContext());
// Split mixed pad_to_multiple_of into dynamic Values + static int64s.
2130 staticPadToMultipleOf);
2136 b.getI64ArrayAttr(paddingDimensions),
2137 dynamicPadToMultipleOf,
2138 staticPadToMultipleOf,
2139 b.getI64ArrayAttr(nofoldFlags),
2140 b.getArrayAttr(transposePaddings),
2142 usePrescribedTensorShapes);
// PadOp::getEffects (FRAGMENT — body lines dropped) and
// PadOp::getMixedPadToMultipleOf, which folds the static pad_to_multiple_of
// attribute with the dynamic operands into a single OpFoldResult list.
2145void PadOp::getEffects(
2153SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
2155 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(),
b);
// transform::PadOp::apply (FRAGMENT — extraction dropped lines; numeric
// prefixes are fused source line numbers).
// Visible logic: for each LinalgOp payload, translate the op's attributes
// (nofold_flags, padding_values, transpose_paddings, pad_to_multiple_of,
// copy_back_op, use_prescribed_tensor_shapes) into LinalgPaddingOptions,
// run the padding transform, and publish padded ops / pad ops / copy-back
// ops as the three results.
2158DiagnosedSilenceableFailure
2159transform::PadOp::apply(transform::TransformRewriter &rewriter,
2160 transform::TransformResults &results,
2161 transform::TransformState &state) {
2162 auto transformOp = cast<TransformOpInterface>(getOperation());
2163 SmallVector<Operation *> paddedOps, padOps, copyBackOps;
// Only LinalgOp payloads are supported.
2166 auto linalgTarget = dyn_cast<LinalgOp>(
target);
2167 if (!linalgTarget) {
2168 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
2169 diag.attachNote(
target->getLoc()) <<
"target op";
// Convert integer nofold flags to booleans.
2174 SmallVector<bool> nofoldFlags;
2175 for (int64_t packPadding :
2177 nofoldFlags.push_back(
static_cast<bool>(packPadding));
// Convert the padding values to attributes typed per operand element type;
// string attributes are parsed, typed attributes are checked directly.
2180 SmallVector<Attribute> paddingValues;
2181 for (
auto const &[untypedAttr, elementOrTensorType] :
2182 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2185 paddingValues.push_back(untypedAttr);
2188 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2190 emitOpError(
"expects padding values to be typed attributes or poison");
2195 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2199 if (!parsedAttr || parsedAttr.getType() != elementType) {
2201 << elementType <<
", got " << untypedAttr;
2202 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2205 paddingValues.push_back(parsedAttr);
2209 if (attr.getType() != elementType) {
2211 << elementType <<
", got " << attr;
2212 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2215 paddingValues.push_back(attr);
// Extract transpose vectors from the nested ArrayAttr.
2219 SmallVector<SmallVector<int64_t>> transposePaddings;
2220 for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2222 cast<ArrayAttr>(transposeVector)));
// Resolve mixed pad_to_multiple_of to constant integers; default to 1 per
// padding dimension when absent.
2229 SmallVector<int64_t> padToMultipleOf;
2231 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2234 if (padToMultipleOf.empty())
2236 SmallVector<int64_t>(
options.paddingDimensions.size(), 1);
2238 options.padToMultipleOf = padToMultipleOf;
2239 options.paddingValues = paddingValues;
2240 options.nofoldFlags = nofoldFlags;
// Map the copy_back_op string to the options enum; verify() restricts the
// legal values, so any other string is a programming error.
2241 if (getCopyBackOp() ==
2242 bufferization::MaterializeInDestinationOp::getOperationName()) {
2243 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2244 BufferizationMaterializeInDestination;
2245 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2246 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2247 }
else if (getCopyBackOp() == kCopyOpNone) {
2248 options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2250 llvm_unreachable(
"unsupported copy_back op");
// Optionally pin dynamic dims to prescribed sizes (tensor semantics only).
2253 bool irChanged =
false;
2254 if (getUsePrescribedTensorShapes() &&
2255 linalgTarget.hasPureTensorSemantics()) {
2256 OpBuilder::InsertionGuard g(rewriter);
2258 for (OpOperand &operand : linalgTarget->getOpOperands()) {
2259 for (
auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
// Static dims need no prescribed size — only dynamic ones are pinned.
2260 if (ShapedType::isStatic(dim))
2262 options.setSizeToPadTo(operand.getOperandNumber(), i,
2264 operand.get().getLoc(),
// Run the padding rewrite and collect replacements / new pad ops.
2271 SmallVector<Value> replacements;
2272 SmallVector<tensor::PadOp> newPadOps;
2274 replacements, newPadOps))) {
2280 auto diag = emitSilenceableError() <<
"failed to pad op";
2281 diag.attachNote(
target->getLoc()) <<
"target op";
2290 rewriter.
replaceOp(linalgTarget, replacements);
2291 paddedOps.push_back(paddedOp);
2292 padOps.append(newPadOps.begin(), newPadOps.end());
// Copy-back ops are the defining ops of the replacement values (dedup'd).
2293 if (
options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2294 for (Value v : replacements) {
2295 Operation *copyBackOp = v.getDefiningOp();
2296 if (!llvm::is_contained(copyBackOps, copyBackOp))
2297 copyBackOps.push_back(copyBackOp);
2302 results.
set(cast<OpResult>(getPadded()), paddedOps);
2303 results.
set(cast<OpResult>(getPad()), padOps);
2304 results.
set(cast<OpResult>(getCopy()), copyBackOps);
// transform::PadOp::verify (FRAGMENT): checks that nofold_flags are 0/1,
// padding_dimensions are non-negative, pad_to_multiple_of (if present)
// matches the number of padding dimensions, each transpose_paddings entry is
// a permutation, and copy_back_op is one of the three supported names.
2308LogicalResult transform::PadOp::verify() {
2309 SmallVector<int64_t> nofoldFlags =
2311 if (any_of(nofoldFlags, [](int64_t packPadding) {
2312 return packPadding != 0 && packPadding != 1;
2315 <<
"expects nofold_flags to contain booleans (0/1), found "
2316 << getNofoldFlags();
2319 SmallVector<int64_t> paddingDimensions =
2321 if (any_of(paddingDimensions,
2322 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2323 return emitOpError() <<
"expects padding_dimensions to contain positive "
2325 << getPaddingDimensions();
2327 if (!getMixedPadToMultipleOf().empty()) {
2328 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2329 return emitOpError() <<
"expects as many multiples as padding_dimensions";
// Each transpose vector must permute [0, size).
2332 ArrayAttr transposes = getTransposePaddings();
2333 for (Attribute attr : transposes) {
2335 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2336 if (!std::is_permutation(sequence.begin(), sequence.end(),
2337 transpose.begin(), transpose.end())) {
2339 <<
"expects transpose_paddings to be a permutation, found "
// copy_back_op must be materialize_in_destination, linalg.copy, or "none".
2343 if (getCopyBackOp() !=
2344 bufferization::MaterializeInDestinationOp::getOperationName() &&
2345 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2346 getCopyBackOp() != kCopyOpNone)
// PadTilingInterfaceOp build overloads, getEffects and getMixedPaddingSizes
// (FRAGMENTS — parameter lists and bodies partially dropped). The second
// build overload dispeds mixed OpFoldResult padding sizes into dynamic
// values + static attribute before delegating to the first.
2355void transform::PadTilingInterfaceOp::build(OpBuilder &
b,
2358 ArrayRef<int64_t> paddingSizes,
2359 bool padToMultipleOf) {
2360 auto resultType = transform::AnyOpType::get(
b.getContext());
2369 :
b.getDenseI64ArrayAttr(paddingSizes)),
// pad_to_multiple_of is a UnitAttr: present iff true.
2371 padToMultipleOf ?
b.getUnitAttr() :
nullptr);
2374void transform::PadTilingInterfaceOp::build(
2376 ArrayRef<OpFoldResult> mixedPaddingSizes,
bool padToMultipleOf) {
2377 auto resultType = transform::AnyOpType::get(
b.getContext());
2378 SmallVector<int64_t> staticPaddingSizes;
2379 SmallVector<Value> dynamicPaddingSizes;
2381 staticPaddingSizes);
2387 dynamicPaddingSizes,
2392void transform::PadTilingInterfaceOp::getEffects(
2393 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Folds static sizes attribute with dynamic size operands (body dropped).
2400SmallVector<OpFoldResult>
2401transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
// transform::PadTilingInterfaceOp::apply (FRAGMENT — lines dropped).
// Visible logic: for each payload implementing TilingInterface +
// IndexingMapOpInterface, type-check/parse the padding values, build
// PadTilingInterfaceOptions, pad the op, replace it with the sliced
// results, and publish padded ops and pad ops.
2406DiagnosedSilenceableFailure
2407transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2408 transform::TransformResults &results,
2409 transform::TransformState &state) {
2410 SmallVector<Operation *> paddedOps, padOps;
2413 auto targetOp = dyn_cast<TilingInterface>(
target);
2415 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2416 diag.attachNote(
target->getLoc()) <<
"target op";
// The transform also requires indexing maps to compute padded sizes.
2423 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2424 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2426 diag.attachNote(
target->getLoc()) <<
"target op";
// Convert padding values per operand element type (same scheme as PadOp).
2431 SmallVector<Attribute> paddingValues;
2432 for (
auto const &[untypedAttr, elementOrTensorType] :
2433 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2434 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2438 paddingValues.push_back(untypedAttr);
2442 emitOpError(
"expects padding values to be typed attributes or poison");
2446 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2450 if (!parsedAttr || parsedAttr.getType() != elementType) {
2452 << elementType <<
", got " << attr;
2453 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2456 paddingValues.push_back(parsedAttr);
2460 if (attr.getType() != elementType) {
2462 << elementType <<
", got " << attr;
2463 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2466 paddingValues.push_back(attr);
// Assemble the options from the op's attributes.
2470 PadTilingInterfaceOptions
options;
2471 options.setPaddingValues(paddingValues)
2472 .setPaddingSizes(getMixedPaddingSizes())
2473 .setPadToMultipleOf(getPadToMultipleOf());
2475 OpBuilder::InsertionGuard g(rewriter);
2478 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options);
2479 if (
failed(maybePadOps)) {
2480 auto diag = emitSilenceableError() <<
"failed to pad op";
2481 diag.attachNote(
target->getLoc()) <<
"target op";
// Unpack the padding result and swap the original op for the slices.
2484 const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
2487 paddedOps.push_back(paddedOp);
2488 padOps.append(paddedOperands.begin(), paddedOperands.end());
2489 rewriter.
replaceOp(targetOp.getOperation(), slicedResults);
2492 results.
set(cast<OpResult>(getPadded()), paddedOps);
2493 results.
set(cast<OpResult>(getPad()), padOps);
2497LogicalResult transform::PadTilingInterfaceOp::verify() {
return success(); }
// HoistPadBuildPackingLoopNestOp::apply + ::verify (FRAGMENTS).
// apply: requires exactly one tensor.pad target and one scf.for loop handle,
// builds the packing loop nest, and returns either the hoisted pad op (no
// cloned loop IVs) or the outermost packed loop. verify: transpose must be a
// permutation of [0, size).
2503DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2504 transform::TransformRewriter &rewriter,
2505 transform::TransformResults &transformResults,
2506 transform::TransformState &state) {
2509 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2511 <<
"requires exactly one target and one loop handle (got "
2512 << llvm::range_size(targetOps) <<
" and "
2513 << llvm::range_size(loopOps) <<
")";
2516 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2517 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2518 if (!padOp || !loopOp)
2521 FailureOr<linalg::detail::PackingResult>
result =
// No cloned loop IVs => the pad itself was hoisted; return it directly.
2527 if (
result->clonedLoopIvs.empty()) {
2528 transformResults.
set(cast<OpResult>(getPackingLoop()),
2529 {
result->hoistedPadOp.getOperation()});
2532 auto outerPackedLoop =
2534 transformResults.
set(cast<OpResult>(getPackingLoop()),
2535 {outerPackedLoop.getOperation()});
2539LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2540 ArrayRef<int64_t> transpose = getTranspose();
2541 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2542 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2544 return emitOpError() <<
"expects transpose to be a permutation, found "
// HoistPadBuildPackingLoopNestOp::getEffects (body dropped), plus
// HoistPadOp::applyToOne and ::verify (FRAGMENTS). applyToOne hoists a
// tensor.pad out of enclosing loops (collecting any transpose ops created);
// verify requires transpose to be a permutation.
2550void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2551 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2558DiagnosedSilenceableFailure
2559transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2561 transform::ApplyToEachResultList &results,
2562 transform::TransformState &state) {
2563 tensor::PadOp hoistedPadOp;
2564 SmallVector<TransposeOp> transposeOps;
2565 FailureOr<Value>
result =
2567 hoistedPadOp, transposeOps);
// Failure path (success handling lines were dropped by extraction).
2578 return emitDefaultSilenceableFailure(
target);
2581LogicalResult transform::HoistPadOp::verify() {
2582 ArrayRef<int64_t> transpose = getTranspose();
2583 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2584 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2586 return emitOpError() <<
"expects transpose to be a permutation, found "
// transform::PromoteOp::applyToOne (FRAGMENT — lines dropped).
// Translates the op's attributes into LinalgPromotionOptions (operands to
// promote, full-tile options, alloca, alignment, memory space, GPU mapping)
// and runs promotion; any unsupported configuration emits the default
// definite failure.
2596DiagnosedSilenceableFailure
2597transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2599 transform::ApplyToEachResultList &results,
2600 transform::TransformState &state) {
2601 LinalgPromotionOptions promotionOptions;
2602 if (!getOperandsToPromote().empty())
2605 if (getUseFullTilesByDefault())
2607 getUseFullTilesByDefault());
2608 if (getUseOriginalSubviewSize())
2612 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2613 if (!getUseFullTileBuffers().empty())
2615 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2616 if (getAlignment().has_value())
2617 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2618 if (getMemorySpace().has_value())
2619 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
// GPU mapping: only a single workgroup/private address-space entry is
// supported; anything else is a definite failure.
2621 if (getMapping().has_value()) {
2623 auto mapping = *getMapping();
2624 if (mapping.size() > 1)
2625 return emitDefaultDefiniteFailure(
target);
2627 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2629 if (addressSpace.getAddressSpace() ==
2630 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2637 }
else if (addressSpace.getAddressSpace() ==
2638 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2646 return emitDefaultDefiniteFailure(
target);
2651 return emitDefaultDefiniteFailure(
target);
2656 return emitDefaultDefiniteFailure(
target);
// transform::ReplaceOp apply/getEffects/verify (FRAGMENTS).
// apply: rejects payload ops with operands or non-isolated regions, then
// clones the single op in the body region over each payload op and returns
// the replacements. verify: the body region must hold exactly one
// operand-free, isolated-from-above operation.
2665DiagnosedSilenceableFailure
2666transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2667 TransformResults &transformResults,
2668 TransformState &state) {
// First pass: validate all payload ops before mutating anything.
2672 for (Operation *
target : payload) {
2673 if (
target->getNumOperands() > 0)
2675 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2676 target->getNumRegions() > 0)
2678 <<
"expected target that is isolated from above";
// Second pass: clone the pattern op over each payload op.
2682 Operation *pattern = &getBodyRegion().front().front();
2683 SmallVector<Operation *> replacements;
2684 for (Operation *
target : payload) {
// Skip payload nested inside this transform op itself.
2685 if (getOperation()->isAncestor(
target))
2692 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2696void transform::ReplaceOp::getEffects(
2697 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2703LogicalResult transform::ReplaceOp::verify() {
2704 if (!getBodyRegion().hasOneBlock())
2706 if (std::distance(getBodyRegion().front().begin(),
2707 getBodyRegion().front().end()) != 1)
2708 return emitOpError() <<
"expected one operation in block";
2709 Operation *
replacement = &getBodyRegion().front().front();
2712 <<
"expected replacement without operands";
2713 if (!
replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2716 <<
"expect op that is isolated from above";
// transform::ScalarizeOp::applyToOne (FRAGMENT — lines dropped).
// Tiles every dimension of the target down to size 1 where possible: the
// tile-size callback computes the loop ranges from operand shapes and picks
// 1 per dimension, then tileUsingSCF performs the tiling.
2724DiagnosedSilenceableFailure
2725transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2727 transform::ApplyToEachResultList &results,
2728 transform::TransformState &state) {
2729 scf::SCFTilingOptions tilingOptions;
2730 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &
b, Operation *) {
2731 SmallVector<OpFoldResult> tileSizes;
2732 Location loc =
target.getLoc();
2733 SmallVector<OpFoldResult> allShapeSizes =
2734 target.createFlatListOfOperandDims(
b, loc);
2735 AffineMap map =
target.getShapesToLoopsMap();
// Map flat operand dims through the shapes-to-loops map to loop sizes.
2738 SmallVector<OpFoldResult> shapeSizes =
2743 for (OpFoldResult shapeSize : shapeSizes) {
// Tile size 1 scalarizes the dimension (alternative branch dropped).
2745 :
b.getIndexAttr(1));
2750 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2751 rewriter, cast<TilingInterface>(
target.getOperation()), tilingOptions);
2752 if (
failed(maybeTilingResult))
2753 return emitDefaultDefiniteFailure(
target);
2755 if (
target->getNumResults())
2760 results.
reserve(maybeTilingResult->tiledOps.size());
2761 for (Operation *tiled : maybeTilingResult->tiledOps)
// transform::ConvertToLoopsOp::apply (FRAGMENT): lowers each TilingInterface
// payload to an explicit scf.for loop nest and returns the generated loops.
2770DiagnosedSilenceableFailure
2771transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2772 transform::TransformResults &results,
2773 transform::TransformState &state) {
2774 SmallVector<Operation *> loops;
2776 auto tilingOp = dyn_cast<TilingInterface>(*
target);
2778 DiagnosedSilenceableFailure
diag =
2779 emitSilenceableError()
2780 <<
"expected the payload to implement TilingInterface";
2781 diag.attachNote(
target->getLoc()) <<
"payload op";
2785 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2786 scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2787 if (
failed(generatedLoops))
2788 return emitDefaultDefiniteFailure(
target);
2789 for (scf::ForOp &loop : *generatedLoops) {
2790 loops.push_back(loop.getOperation());
2794 results.
set(cast<OpResult>(getResult()), loops);
// RewriteInDestinationPassingStyleOp::applyToOne (FRAGMENT): dispatches on
// the payload op kind (tensor.from_elements / tensor.generate / tensor.pad)
// via TypeSwitch to rewrite it in destination-passing style; failure yields
// the default silenceable failure.
2802DiagnosedSilenceableFailure
2803transform::RewriteInDestinationPassingStyleOp::applyToOne(
2804 transform::TransformRewriter &rewriter, Operation *
target,
2805 transform::ApplyToEachResultList &results,
2806 transform::TransformState &state) {
2808 FailureOr<Operation *> maybeResult =
2810 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2811 [&rewriter](
auto op) {
2815 return emitDefaultSilenceableFailure(
target);
// SplitOp::apply (FRAGMENT — lines dropped). Splits each payload structured
// op at a chunk size along getDimension(). Two modes: multiway (single
// target split repeatedly at each chunk size) and pairwise (each target
// split once into first/second parts). Chunk sizes come from a static
// attribute or from the dynamic handle (op handle or param handle).
2824DiagnosedSilenceableFailure
2825SplitOp::apply(transform::TransformRewriter &rewriter,
2826 TransformResults &results, TransformState &state) {
2828 SmallVector<Operation *> payload =
2831 bool isMultiwaySplit = getMultiway();
2833 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2835 <<
"requires exactly one target when "
2836 "multiway split is enabled (got "
2837 << llvm::range_size(payload) <<
")";
2840 SmallVector<OpFoldResult> chunkSizes;
2842 if (!isMultiwaySplit)
2843 chunkSizes.reserve(payload.size());
// Resolve dynamic chunk sizes: an op handle must point at single-result
// index-typed ops; a param handle yields attributes directly.
2845 if (getDynamicChunkSizes()) {
2847 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2848 chunkSizes = llvm::map_to_vector(
2849 state.
getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2852 diag = emitSilenceableError()
2853 <<
"expected dynamic split point handle to point to a "
2854 "single-result index-typed op";
2855 diag.attachNote(op->
getLoc()) <<
"dynamic split point";
2860 chunkSizes = llvm::map_to_vector(
2861 state.
getParams(getDynamicChunkSizes()),
2862 [](Attribute attr) {
return OpFoldResult(attr); });
2864 if (
diag.isSilenceableFailure())
// Pairwise mode needs one chunk size per target.
2869 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2871 <<
"expected the dynamic split point handle to point to as "
2873 << chunkSizes.size() <<
") as the target handle ("
2874 << payload.size() <<
")";
2877 chunkSizes.resize(payload.size(),
// Shared validation helpers for both modes.
2881 auto checkStructuredOpAndDimensions =
2882 [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2884 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2885 diag.attachNote(loc) <<
"target op";
2889 if (getDimension() >= linalgOp.getNumLoops()) {
2890 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2891 <<
" does not exist in target op";
2892 diag.attachNote(loc) <<
"target op";
2898 auto checkFailureInSplitting =
2899 [&](
bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2908 SmallVector<Operation *> opList;
// Multiway: repeatedly split the tail at each chunk size in turn.
2909 if (isMultiwaySplit) {
2912 TilingInterface head, tail;
2913 Operation *
target = payload.front();
2915 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2918 DiagnosedSilenceableFailure
diag =
2919 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2920 if (
diag.isSilenceableFailure())
2923 for (
auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2926 target = tail.getOperation();
2931 linalgOp = cast<LinalgOp>(
target);
2932 Location loc =
target->getLoc();
2936 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2937 getDimension(), chunkSize);
2940 DiagnosedSilenceableFailure
diag =
2941 checkFailureInSplitting(!head && !tail, loc);
2942 if (
diag.isDefiniteFailure())
2945 opList.push_back(head.getOperation());
2950 opList.push_back(tail.getOperation());
// Pairwise: split each target once into first/second parts.
2954 SmallVector<Operation *> first, second;
2955 Operation *noSecondPart =
nullptr;
2956 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2957 Operation *
target = std::get<0>(pair);
2958 Location loc =
target->getLoc();
2959 LinalgOp linalgOp = dyn_cast<LinalgOp>(
target);
2960 DiagnosedSilenceableFailure
diag =
2961 checkStructuredOpAndDimensions(linalgOp,
target->getLoc());
2963 if (
diag.isSilenceableFailure())
2967 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2968 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2969 getDimension(), std::get<1>(pair));
2972 DiagnosedSilenceableFailure diagSplit =
2973 checkFailureInSplitting(!first.back() && !second.back(), loc);
2978 if (!second.back()) {
// All targets must agree on whether a second part was produced.
2984 if (second.size() != first.size() && !second.empty()) {
2985 auto diag = emitSilenceableError()
2986 <<
"splitting does not produce the second part for a subset "
2989 <<
"expected splitting to produce the second part of all "
2990 "or none of the targets"
2992 <<
"first target with no second part";
2996 opList.append(first);
2997 if (!second.empty())
2998 opList.append(second);
3000 results.
set(cast<OpResult>(getSplitList()), opList);
// SplitOp getEffects / custom parse / print / verify (FRAGMENTS).
// Custom assembly: `%target after <static-int | %dynamic> : type(s)`.
// verify: exactly one of the static chunk size (!= kDynamic) and the dynamic
// chunk-size operand must be present (checked with XOR).
3004void SplitOp::getEffects(
3005 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3007 if (getDynamicChunkSizes())
3013ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &
result) {
3014 OpAsmParser::UnresolvedOperand
target, dynamicChunkSizes;
3015 IntegerAttr staticChunkSizes;
// Try to parse a dynamic operand; absent means a static integer follows.
3019 OptionalParseResult dynamicPointParseResult =
3021 if (!dynamicPointParseResult.
has_value()) {
3022 int64_t staticChunkSizesValue;
3036 if (dynamicPointParseResult.
has_value()) {
3037 Type chunkSizesType;
3050 SplitOp::getStaticChunkSizesAttrName(
result.name).getValue(),
3052 result.addTypes(targetType);
3056void SplitOp::print(OpAsmPrinter &printer) {
3057 printer <<
" " << getTarget() <<
" after ";
3058 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
// kDynamic sentinel means the chunk size is the dynamic operand.
3059 if (staticChunkSize != ShapedType::kDynamic)
3060 printer << staticChunkSize;
3062 printer << getDynamicChunkSizes();
3065 {getStaticChunkSizesAttrName()});
3066 printer <<
" : " << getTarget().getType();
3067 if (staticChunkSize == ShapedType::kDynamic)
3068 printer <<
", " << getDynamicChunkSizes().getType();
3071LogicalResult SplitOp::verify() {
3072 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3073 (getDynamicChunkSizes() ==
nullptr)) {
3074 return emitOpError() <<
"expects either a dynamic or a static split "
3075 "point to be provided";
// SplitReductionOp::build + ::applyToOne (FRAGMENTS).
// build: encodes split factor / insert dimension and the three optional unit
// flags as attributes and declares four AnyOp results. applyToOne: runs the
// (scaling or plain) split-reduction transform and returns init/alloc, fill,
// split op and combining op (one fill result line dropped by extraction).
3084void transform::SplitReductionOp::build(
3085 OpBuilder &builder, OperationState &
result, Value
target,
3086 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
3087 bool useScalingAlgorithm,
bool useAlloc) {
3090 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(
result.name),
3093 SplitReductionOp::getInsertSplitDimensionAttrName(
result.name),
// Unit-attr flags: only added when true.
3095 if (innerParallel) {
3096 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(
result.name),
3099 if (useScalingAlgorithm) {
3101 SplitReductionOp::getUseScalingAlgorithmAttrName(
result.name),
3105 result.addAttribute(SplitReductionOp::getUseAllocAttrName(
result.name),
3108 auto resultType = transform::AnyOpType::get(ctx);
3109 result.addTypes({resultType, resultType, resultType, resultType});
3112DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
3113 transform::TransformRewriter &rewriter, LinalgOp
target,
3114 transform::ApplyToEachResultList &results,
3115 transform::TransformState &state) {
// Control function packaging the op's attributes as split options.
3117 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
3118 unsigned(getInsertSplitDimension()),
3119 bool(getInnerParallel())};
3122 FailureOr<SplitReductionResult> splitResult =
3123 (getUseScalingAlgorithm())
3127 return emitDefaultDefiniteFailure(
target);
3129 results.
push_back(splitResult->initOrAlloc);
3131 results.
push_back(splitResult->splitLinalgOp);
3132 results.
push_back(splitResult->resultCombiningLinalgOp);
// TileReductionUsingForOp::build + ::applyToOne (FRAGMENTS).
// applyToOne: tiles the reduction dimensions of a PartialReductionOpInterface
// payload using scf.for with a partial-reduction strategy; when no reduction
// dims are given, all reduction iterators are tiled.
3140void transform::TileReductionUsingForOp::build(
3141 OpBuilder &builder, OperationState &
result, Value
target,
3142 ArrayRef<int64_t> staticTileSizes) {
3149 auto opTy = transform::AnyOpType::get(ctx);
3155 staticTileSizesAttr);
3158DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
3159 transform::TransformRewriter &rewriter, Operation *
target,
3160 transform::ApplyToEachResultList &results,
3161 transform::TransformState &state) {
3164 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3165 if (!partialReductionOp) {
3168 "Operation should implement PartialReductionOpInterface");
// Default: tile every reduction iterator.
3171 SmallVector<unsigned> reductionDims =
3173 if (reductionDims.empty()) {
3174 for (
auto [idx, iteratorType] :
3175 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3176 if (iteratorType == utils::IteratorType::reduction)
3177 reductionDims.push_back(idx);
3181 scf::SCFTilingOptions
options;
3182 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
3183 options.setReductionTilingStrategy(
3186 options.setReductionDims(reductionDims);
3187 FailureOr<scf::SCFTilingResult>
result =
3188 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3192 "failed to tile using partial reduction");
// Results: init values, parallel tiled ops, merge ops (and loops — some
// result-pushing lines were dropped by extraction).
3195 for (Value initValue :
result->initialValues)
3197 for (
auto *parallelTiledOp :
result->tiledOps)
3199 for (
auto *mergeOp :
result->mergeOps)
// TileReductionUsingForallOp::build + ::applyToOne (FRAGMENTS).
// applyToOne: like the scf.for variant but generates an scf.forall, driven by
// num_threads or tile_sizes, with an optional device mapping attribute.
3209void transform::TileReductionUsingForallOp::build(
3210 OpBuilder &builder, OperationState &
result, Value
target,
3211 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3219 auto opTy = transform::AnyOpType::get(ctx);
3226 staticNumThreadsAttr,
3227 staticTileSizesAttr,
3231DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3232 transform::TransformRewriter &rewriter, Operation *
target,
3233 transform::ApplyToEachResultList &results,
3234 transform::TransformState &state) {
3237 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(
target);
3238 if (!partialReductionOp) {
3241 "Operation should implement PartialReductionOpInterface");
3243 SmallVector<OpFoldResult> numThreads =
3245 SmallVector<OpFoldResult> tileSizes =
3248 scf::SCFTilingOptions
options;
3249 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3250 options.setReductionTilingStrategy(
// num_threads takes precedence over tile_sizes when both are given.
3252 if (!getNumThreads().empty()) {
3253 options.setNumThreads(numThreads);
3255 options.setTileSizes(tileSizes);
3257 if (
auto mapping = getMapping()) {
3258 options.setMapping(mapping.value().getValue());
// Default: tile every reduction iterator.
3260 SmallVector<unsigned> reductionDims =
3262 if (reductionDims.empty()) {
3263 for (
auto [idx, iteratorType] :
3264 llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
3265 if (iteratorType == utils::IteratorType::reduction)
3266 reductionDims.push_back(idx);
3269 options.setReductionDims(reductionDims);
3270 FailureOr<scf::SCFTilingResult>
result =
3271 scf::tileUsingSCF(rewriter, partialReductionOp,
options);
3274 auto diag = emitSilenceableError() <<
"could not tile reduction";
3279 for (Value initValue :
result->initialValues)
3281 for (
auto *parallelTiledOp :
result->tiledOps)
3283 for (
auto *mergeOp :
result->mergeOps)
// ContinuousTileSizesOp apply/verify/getEffects (FRAGMENTS).
// apply: computes continuous tile sizes and chunk sizes for a single target.
// Param-typed results use the static computation (requires static shapes);
// otherwise IR is emitted and the defining ops are returned.
3293DiagnosedSilenceableFailure
3294transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3295 TransformResults &transformResults,
3296 TransformState &state) {
3298 SmallVector<Operation *> targetOps =
3301 if (!llvm::hasSingleElement(targetOps)) {
3303 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3307 Operation *
target = *targetOps.begin();
3308 auto linalgOp = dyn_cast<LinalgOp>(
target);
3309 auto tileableOp = dyn_cast<TilingInterface>(
target);
3314 OpBuilder builder(linalgOp.getContext());
// Param results path: static computation, static shapes required.
3316 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3317 if (linalgOp.hasDynamicShape()) {
3318 auto diag = emitSilenceableError()
3319 <<
"cannot compute parametric tile sizes for dynamically "
3320 "shaped payload op";
3321 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3325 FailureOr<StaticContinuousTileSizeSpecification> spec =
3329 return emitSilenceableError()
3330 <<
"failed to compute multi-size tiling sizes";
// chunkSize[i] = tileSize[i] * tripCount[i].
3333 SmallVector<int64_t> chunkSizes;
3335 for (
auto &&[tileSize, tripCount] :
3336 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3337 chunkSizes.push_back(tileSize * tripCount);
3339 auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3340 return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3345 getI64AttrsFromI64(spec->tileSizes));
3346 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3347 getI64AttrsFromI64(chunkSizes));
// Value results path: emit IR computing the sizes.
3354 OpFoldResult targetSize = builder.
getIndexAttr(getTargetSize());
3355 unsigned dimension = getDimension();
3358 builder, tileableOp, dimension, targetSize,
true);
3360 return emitSilenceableError() <<
"could not generate tile size computation";
3365 auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3370 SmallVector<Value> chunkSizes;
3372 for (
auto &&[tileSize, tripCount] :
3373 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3374 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3375 chunkSizes.push_back(splitPoint);
3378 auto getDefiningOps = [&](ArrayRef<Value> values) {
3379 return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3385 getDefiningOps(spec->tileSizes));
3386 transformResults.
set(cast<OpResult>(getChunkSizes()),
3387 getDefiningOps(chunkSizes));
3392LogicalResult transform::ContinuousTileSizesOp::verify() {
3395 return emitOpError() <<
"expects all results type to be the same";
3401void transform::ContinuousTileSizesOp::getEffects(
3402 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3419 Type &tileSizesType,
3420 Type &chunkSizesType) {
3421 FunctionType funcType;
3423 if (failed(parser.
parseType<FunctionType>(funcType)))
3426 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3427 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3428 "argument and one result";
3430 targetType = funcType.getInput(0);
3431 tileSizesType = chunkSizesType = funcType.getResult(0);
3440void transform::TileUsingForOp::build(
3442 Value
target, ArrayRef<int64_t> staticTileSizes,
3443 ArrayRef<int64_t> interchange,
3444 std::optional<ArrayRef<bool>> scalableSizes) {
3445 return build(builder,
result, loopTypes,
3449 interchange, scalableSizes);
3452void transform::TileUsingForOp::build(
3453 OpBuilder &builder, OperationState &
result, Value
target,
3454 ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3455 std::optional<ArrayRef<bool>> scalableSizes) {
3458 interchange, scalableSizes);
3461void transform::TileUsingForOp::build(
3462 OpBuilder &builder, OperationState &
result, Value
target,
3463 ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3464 std::optional<ArrayRef<bool>> scalableSizes) {
3467 SmallVector<Type> loopTypes(1, builder.
getType<transform::AnyOpType>());
3468 build(builder,
result, loopTypes,
target, mixedTileSizes, interchange,
3472void transform::TileUsingForOp::build(
3474 Value
target, ArrayRef<OpFoldResult> mixedTileSizes,
3475 ArrayRef<int64_t> interchange,
3476 std::optional<ArrayRef<bool>> scalableSizes) {
3477 SmallVector<int64_t> staticTileSizes;
3478 SmallVector<Value> dynamicTileSizes;
3484 unsigned numExpectedLoops =
3485 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3486 SmallVector<Type> resultTypes;
3487 resultTypes.reserve(numExpectedLoops);
3488 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3489 "expected one loop type or as many as loops");
3490 if (loopTypes.size() == 1)
3491 resultTypes.append(numExpectedLoops, loopTypes[0]);
3493 llvm::append_range(resultTypes, loopTypes);
3494 SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(),
false);
3495 if (scalableSizes.has_value())
3496 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3501 staticTileSizesAttr,
3503 expandedScalableSizes);
3506LogicalResult transform::TileUsingForOp::verify() {
3508 return emitOpError(
"expected same number of sizes (")
3510 << getScalableSizes().size() <<
")";
3511 ArrayRef<int64_t> staticSizes = getStaticSizes();
3512 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3513 if (getLoops().size() != numExpectedLoops)
3514 return emitOpError(
"expected number of loops to tile (")
3515 << numExpectedLoops <<
") to match number of `loops` results ("
3516 << getLoops().size() <<
")";
3520DiagnosedSilenceableFailure
3521transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3522 TransformResults &transformResults,
3523 TransformState &state) {
3524 ArrayRef<int64_t> tileSizes = getStaticSizes();
3526 SmallVector<Operation *> targets =
3528 SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3529 SmallVector<SmallVector<int64_t>> paramSizes;
3533 if (isa<TransformParamTypeInterface>(transformValue.getType())) {
3534 dynamicSizeProducers.push_back({});
3535 ArrayRef<Attribute> params = state.
getParams(transformValue);
3536 paramSizes.push_back(llvm::map_to_vector(params, [](Attribute attr) {
3537 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3540 if (paramSizes.back().size() != targets.size()) {
3541 DiagnosedSilenceableFailure
diag =
3542 emitSilenceableError()
3543 <<
"expected as many parameter values ("
3544 << dynamicSizeProducers.back().size() <<
") as target ops ("
3545 << targets.size() <<
")";
3546 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3552 paramSizes.push_back({});
3553 dynamicSizeProducers.push_back(
3556 if (dynamicSizeProducers.back().size() != targets.size()) {
3557 DiagnosedSilenceableFailure
diag =
3558 emitSilenceableError()
3559 <<
"expected as many dynamic size-producing operations ("
3560 << dynamicSizeProducers.back().size() <<
") as target ops ("
3561 << targets.size() <<
")";
3562 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3566 for (Operation *op : dynamicSizeProducers.back()) {
3572 DiagnosedSilenceableFailure
diag =
3573 emitSilenceableError() <<
"expected sizes to be produced by ops "
3574 "with a single index-type result";
3575 diag.attachNote(op->
getLoc()) <<
"size producer op";
3576 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3581 SmallVector<Operation *> tiled;
3582 SmallVector<SmallVector<Operation *, 4>, 4> loops;
3583 loops.resize(getLoops().size());
3584 auto scalableSizes = getScalableSizes();
3585 for (
auto [i, op] : llvm::enumerate(targets)) {
3586 auto tilingInterface = dyn_cast<TilingInterface>(op);
3587 if (!tilingInterface) {
3588 DiagnosedSilenceableFailure
diag =
3589 emitSilenceableError()
3590 <<
"only ops implementing TilingInterface are supported";
3591 diag.attachNote(op->
getLoc()) <<
"target op";
3594 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3595 DiagnosedSilenceableFailure
diag =
3596 emitSilenceableError()
3597 <<
"too many tiles provided, expected at most "
3598 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3599 << tileSizes.size();
3600 diag.attachNote(op->
getLoc()) <<
"target op";
3604 scf::SCFTilingOptions tilingOptions;
3605 if (tileSizes.empty()) {
3606 tilingOptions.setTileSizeComputationFunction(
3607 [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3611 tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &
b,
3613 SmallVector<OpFoldResult> sizes;
3614 sizes.reserve(tileSizes.size());
3615 unsigned dynamicIdx = 0;
3617 for (
auto [ofrIdx, ofr] : llvm::enumerate(
getMixedSizes())) {
3618 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3619 if (scalableSizes[ofrIdx]) {
3621 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3623 vector::VectorScaleOp::create(
b, getLoc(),
b.getIndexType());
3625 arith::MulIOp::create(
b, getLoc(), val, vscale).getResult());
3627 sizes.push_back(attr);
3631 ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3632 ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3634 assert((dynamicSizes.empty() ^ params.empty()) &&
3635 "expected either dynamic sizes or parameters");
3636 if (!params.empty()) {
3637 sizes.push_back(
b.getIndexAttr(params[index]));
3639 sizes.push_back(dynamicSizes[index]->getResult(0));
3646 tilingOptions.setInterchange(getInterchange());
3647 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3648 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3649 if (
failed(maybeTilingResult))
3652 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3654 tiled.append(maybeTilingResult->tiledOps);
3655 for (
const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3656 loops[en2.index()].push_back(en2.value());
3659 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3660 for (
const auto &en : llvm::enumerate(loops))
3661 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3666SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3668 ArrayRef<int64_t> tileSizes = getStaticSizes();
3669 SmallVector<OpFoldResult> results;
3670 results.reserve(tileSizes.size());
3671 unsigned dynamicPos = 0;
3673 for (int64_t size : tileSizes) {
3674 if (size == ShapedType::kDynamic) {
3675 results.push_back(dynamic[dynamicPos++]);
3683void transform::TileUsingForOp::getEffects(
3684 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3695void transform::TileUsingForallOp::build(OpBuilder &builder,
3697 ArrayRef<int64_t> staticTileSizes,
3698 transform::TileSizesSpec,
3700 return build(builder,
result,
3708void transform::TileUsingForallOp::build(OpBuilder &builder,
3710 ArrayRef<OpFoldResult> mixedTileSizes,
3711 transform::TileSizesSpec,
3713 SmallVector<int64_t> staticTileSizes;
3714 SmallVector<Value> dynamicTileSizes;
3720 auto operationType = transform::AnyOpType::get(ctx);
3723 TypeRange{operationType, operationType},
3730 staticTileSizesAttr,
3734void transform::TileUsingForallOp::build(OpBuilder &builder,
3736 ArrayRef<int64_t> staticNumThreads,
3737 transform::NumThreadsSpec,
3741 NumThreadsSpec(), mapping);
3744void transform::TileUsingForallOp::build(OpBuilder &builder,
3746 ArrayRef<OpFoldResult> mixedNumThreads,
3747 transform::NumThreadsSpec,
3749 SmallVector<int64_t> staticNumThreads;
3750 SmallVector<Value> dynamicNumThreads;
3757 auto operationType = transform::AnyOpType::get(ctx);
3760 TypeRange{operationType, operationType},
3766 staticNumThreadsAttr,
3773static SmallVector<OpFoldResult>
3779 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3781 for (
auto [lb,
ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3783 rewriter, loc, normalizedUbExpr, {lb,
ub, step});
3784 normalizedUbs.push_back(normalizedUb);
3786 return normalizedUbs;
3802 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3805 denormalizedIvs.push_back(
3808 return denormalizedIvs;
3819 scf::ForallOp loop) {
3836 auto normalizedForallOp = scf::ForallOp::create(
3837 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3838 loop.getOutputs(), loop.getMapping(),
3841 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3843 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3848 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3849 normalizedForallOp.getRegionIterArgs().end());
3850 Block *origLoopBlock = loop.getBody();
3851 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3853 rewriter.
replaceOp(loop, normalizedForallOp);
3854 return normalizedForallOp;
3862 scf::SCFTilingResult &tilingResult) {
3864 auto tileableOp = dyn_cast<TilingInterface>(
target);
3867 transformOp.emitSilenceableError()
3868 <<
"only TilingInterface ops are supported";
3869 diag.attachNote(
target->getLoc()) <<
"target op";
3873 scf::SCFTilingOptions
options;
3874 options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3875 if (!mixedNumThreads.empty()) {
3876 options.setNumThreads(mixedNumThreads);
3878 options.setTileSizes(mixedTileSizes);
3881 options.setMapping(mapping.value().getValue());
3883 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3884 scf::tileUsingSCF(rewriter, tileableOp,
options);
3886 if (failed(maybeTilingResult))
3887 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3889 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3891 tilingResult = *maybeTilingResult;
3893 if (mixedNumThreads.empty()) {
3894 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3897 scf::ForallOp normalizedForallOp =
3899 tilingResult.loops.front() = normalizedForallOp;
3909 auto transformOp = cast<TransformOpInterface>(getOperation());
3918 getPackedNumThreads()
3920 state, transformOp, mixedNumThreads, getPackedNumThreads())
3922 state, transformOp, mixedNumThreads, getMixedNumThreads());
3926 status = getPackedTileSizes()
3928 state, transformOp, mixedTileSizes, getPackedTileSizes())
3930 state, transformOp, mixedTileSizes, getMixedTileSizes());
3935 scf::SCFTilingResult tilingResult;
3937 rewriter, state, transformOp,
target, mixedNumThreads, mixedTileSizes,
3938 getMapping(), tilingResult);
3939 if (!
diag.succeeded())
3941 tileOps.push_back(tilingResult.loops.front());
3942 tiledOps.append(tilingResult.tiledOps);
3945 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3946 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3951void transform::TileUsingForallOp::getEffects(
3952 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3962SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3967SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3972LogicalResult TileUsingForallOp::verify() {
3973 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3974 static_cast<int>(getPackedNumThreads() != Value());
3975 if (numThreadsSpec > 1)
3977 "num_threads and packed_num_threads are mutually exclusive");
3978 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3979 static_cast<int>(getPackedTileSizes() != Value());
3980 if (tileSizesSpec > 1)
3982 "tile_sizes and packed_tile_sizes are mutually exclusive");
3983 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3984 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3985 "must be specified");
3993void transform::VectorizeChildrenAndApplyPatternsOp::build(
3994 OpBuilder &builder, OperationState &
result, Value
target,
3995 bool foldTypeExtensionsIntoContract,
bool vectorizePadding,
3996 bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3998 if (foldTypeExtensionsIntoContract) {
4000 VectorizeChildrenAndApplyPatternsOp::
4001 getFoldTypeExtensionsIntoContractAttrName(
result.name),
4004 if (vectorizePadding) {
4006 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4010 if (vectorizeExtract) {
4012 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4016 if (flatten1DDepthwiseConv) {
4018 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4028struct VectorizationPattern :
public RewritePattern {
4029 explicit VectorizationPattern(MLIRContext *context,
4030 bool vectorizeExtract =
false,
4031 bool flattenConv =
false)
4032 : RewritePattern(MatchAnyOpTypeTag(), 1, context),
4033 vectorizeNDExtract(vectorizeExtract),
4034 flatten1DDepthwiseConv(flattenConv) {}
4035 LogicalResult matchAndRewrite(Operation *op,
4036 PatternRewriter &rewriter)
const override {
4039 "Unsupported Op, cannot vectorize");
4040 FailureOr<VectorizationResult> vectorResults =
4042 {}, vectorizeNDExtract,
4043 flatten1DDepthwiseConv);
4044 if (
failed(vectorResults))
4046 rewriter.
replaceOp(op, vectorResults->replacements);
4053 bool vectorizeNDExtract =
false;
4057 bool flatten1DDepthwiseConv =
false;
4061DiagnosedSilenceableFailure
4062transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4063 transform::TransformRewriter &rewriter, Operation *
target,
4064 transform::ApplyToEachResultList &results,
4065 transform::TransformState &state) {
4066 if (!
target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
4067 auto diag = this->
emitOpError(
"requires isolated-from-above targets");
4068 diag.attachNote(
target->getLoc()) <<
"non-isolated target";
4073 RewritePatternSet patterns(ctx);
4074 patterns.
add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4075 getFlatten_1dDepthwiseConv());
4077 if (!getDisableTransferPermutationMapLoweringPatterns())
4080 if (!getDisableMultiReductionToContractPatterns())
4085 patterns.
add<linalg::LinalgCopyVTRForwardingPattern,
4086 linalg::LinalgCopyVTWForwardingPattern>(ctx,
4088 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
4089 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
4092 patterns.
add<CopyVectorizationPattern>(ctx);
4094 if (getFoldTypeExtensionsIntoContract())
4097 if (getVectorizePadding()) {
4105 TrackingListener listener(state, *
this);
4108 GreedyRewriteConfig().setListener(&listener))))
4109 return emitDefaultDefiniteFailure(
target);
4119DiagnosedSilenceableFailure transform::VectorizeOp::apply(
4120 transform::TransformRewriter &rewriter,
4121 mlir::transform::TransformResults &transformResults,
4122 mlir::transform::TransformState &state) {
4124 if (std::empty(targets))
4126 auto transformOp = cast<TransformOpInterface>(getOperation());
4127 SmallVector<int64_t> vectorSizes;
4129 state, transformOp, getMixedVectorSizes(), vectorSizes);
4134 for (Operation *
target : targets) {
4137 <<
"Unsupported Op, cannot vectorize";
4139 FailureOr<VectorizationResult> vectorResults =
4141 getVectorizeNdExtract().value_or(
false),
4143 getAssumeDynamicDimsMatchVecSizes().value_or(
false),
4144 getCreateNamedContraction().value_or(
false));
4145 if (
failed(vectorResults)) {
4147 <<
"Attempted to vectorize, but failed";
4155void transform::VectorizeOp::getEffects(
4156 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
4162SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
4167LogicalResult transform::VectorizeOp::verify() {
4168 if (getStaticVectorSizes().size() != getScalableSizes().size())
4169 return emitOpError(
"expected same number of vector sizes (")
4170 << getStaticVectorSizes().size() <<
") and scalable sizes ("
4171 << getScalableSizes().size() <<
")";
4179DiagnosedSilenceableFailure
4180transform::HoistRedundantVectorTransfersOp::applyToOne(
4181 transform::TransformRewriter &rewriter, func::FuncOp
target,
4182 transform::ApplyToEachResultList &results,
4183 transform::TransformState &state) {
4196DiagnosedSilenceableFailure
4197transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4198 transform::TransformRewriter &rewriter, mlir::Operation *
target,
4199 transform::ApplyToEachResultList &results,
4200 transform::TransformState &state) {
4211DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
4212 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4213 transform::ApplyToEachResultList &results,
4214 transform::TransformState &state) {
4216 auto maybeTransformed =
4219 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4222 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4225 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4228 .Case([&](linalg::Conv2DNchwFchwOp op) {
4231 .Default([&](Operation *op) {
4234 if (
failed(maybeTransformed))
4235 return emitDefaultSilenceableFailure(
target);
4237 results.
push_back(maybeTransformed->first);
4239 results.
push_back(maybeTransformed->second);
4247DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
4248 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4249 transform::ApplyToEachResultList &results,
4250 transform::TransformState &state) {
4254 <<
"only elementwise flattening is supported";
4257 if (
target.getNumLoops() <= 1) {
4264 std::iota(reassociation.begin(), reassociation.end(), 0);
4265 auto maybeFlattened =
4267 if (
failed(maybeFlattened))
4269 <<
"attempted to flatten, but failed";
4270 results.
push_back(maybeFlattened->collapsedOp);
4279DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4280 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4281 transform::ApplyToEachResultList &results,
4282 transform::TransformState &state) {
4284 auto maybeTransformed =
4286 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4289 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4292 .Default([&](Operation *op) {
4295 if (
failed(maybeTransformed))
4296 return emitDefaultSilenceableFailure(
target);
4306DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4307 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4308 transform::ApplyToEachResultList &results,
4309 transform::TransformState &state) {
4311 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4312 auto maybeTransformed =
4314 .Case([&](linalg::MatmulOp op) {
4317 .Case([&](linalg::BatchMatmulOp op) {
4320 .Default(failure());
4321 if (
failed(maybeTransformed))
4331template <
typename OpTy>
4332static DiagnosedSilenceableFailure
4336 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4337 tensor::ParallelInsertSliceOp>() &&
4340 if (
auto copySource =
4341 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4349 if (isa<mlir::ParallelCombiningOpInterface>(
target.getOperation()))
4352 Value extracted = tensor::ExtractSliceOp::create(
4355 Value copied = linalg::CopyOp::create(rewriter,
target.getLoc(),
4356 target.getSource(), extracted)
4368DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4369 transform::TransformRewriter &rewriter, Operation *targetOp,
4370 transform::ApplyToEachResultList &results,
4371 transform::TransformState &state) {
4374 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4375 return doit(rewriter,
target, results, state);
4376 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4377 return doit(rewriter,
target, results, state);
4379 DiagnosedSilenceableFailure
diag =
4380 emitSilenceableError()
4381 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4382 diag.attachNote(targetOp->
getLoc()) <<
"target op";
4390DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4391 transform::TransformRewriter &rewriter, Operation *
target,
4392 transform::ApplyToEachResultList &results,
4393 transform::TransformState &state) {
4395 if (!isa<linalg::CopyOp, tensor::PadOp>(
target)) {
4396 DiagnosedSilenceableFailure
diag =
4397 emitSilenceableError()
4398 <<
"only linalg.copy and tensor.pad target ops are supported";
4399 diag.attachNote(
target->getLoc()) <<
"target op";
4402 assert(
target->getNumResults() == 1 &&
"expected single result");
4403 auto resultShapedType = cast<ShapedType>(
target->getResult(0).getType());
4404 if (!resultShapedType.hasStaticShape()) {
4405 DiagnosedSilenceableFailure
diag =
4406 emitSilenceableError()
4407 <<
"only statically sized ops of rank <= 3 are supported";
4408 diag.attachNote(
target->getLoc()) <<
"target op";
4413 int64_t desiredBitAlignment = getDesiredBitAlignment();
4414 int64_t eltBitwidth =
4415 resultShapedType.getElementType().getIntOrFloatBitWidth();
4416 if (desiredBitAlignment % eltBitwidth != 0) {
4417 desiredBitAlignment = eltBitwidth;
4420 gpu::CopyMappingInfo mapping(
4422 getTotalNumThreads(),
4423 desiredBitAlignment,
4424 resultShapedType.getShape(),
4427 resultShapedType.getElementType().getIntOrFloatBitWidth());
4428 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4429 DiagnosedSilenceableFailure
diag =
4430 emitSilenceableError()
4431 <<
"too few threads to map copy op to threads on the most minor "
4432 "dimension, given alignment and vector size constraints, try "
4433 "smaller tile size of mapping to more threads";
4434 diag.attachNote(
target->getLoc()) <<
"target op";
4440 scf::SCFTilingResult tilingResult;
4447 ArrayRef<OpFoldResult>{},
4448 b.getArrayAttr(mapping.threadMapping),
4450 if (!
diag.succeeded())
4453 results.
push_back(tilingResult.loops.front());
4454 for (
auto *op : tilingResult.tiledOps)
4463DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4464 transform::TransformRewriter &rewriter, linalg::LinalgOp
target,
4465 transform::ApplyToEachResultList &results,
4466 transform::TransformState &state) {
4468 FailureOr<Operation *> maybeTransformed = failure();
4470 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4475 .Default([&](Operation *op) {
return false; });
4478 return emitSilenceableError()
4479 <<
"this operation is not supported to convert to Winograd Conv2D";
4482 if (
failed(maybeTransformed)) {
4483 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4490DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4491 transform::TransformRewriter &rewriter, Operation *
target,
4492 transform::ApplyToEachResultList &results,
4493 transform::TransformState &state) {
4495 FailureOr<Operation *> maybeTransformed = failure();
4498 .Case([&](linalg::WinogradFilterTransformOp op) {
4502 .Case([&](linalg::WinogradInputTransformOp op) {
4506 .Case([&](linalg::WinogradOutputTransformOp op) {
4513 DiagnosedSilenceableFailure
diag =
4514 emitSilenceableError()
4515 <<
"this operation is not supported to decompose into other operations";
4516 diag.attachNote(
target->getLoc()) <<
"target op";
4520 if (
failed(maybeTransformed)) {
4521 DiagnosedSilenceableFailure
diag =
4522 emitSilenceableError() <<
"decompose Winograd operations failed";
4523 diag.attachNote(
target->getLoc()) <<
"target op";
4531#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4533#define GET_OP_CLASSES
4534#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
memberIdxs push_back(ArrayAttr::get(parser.getContext(), values))
static llvm::ManagedStatic< PassManagerOptions > options
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Base type for affine expression.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the error.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation, bool PoisonPaddingOk=false)
Patterns to bubble up or down data layout ops across other operations.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
void populateExtractSliceSinkingPatterns(RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation)
Patterns to sink extract slice across other operations.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice + copy.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Canonicalization patterns relevant to apply after tiling patterns.
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp, const GenericOpSpecializationOptions &options={})
Replace the given GenericOp with a namedOp or categoryOp.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors and memref.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
std::function< bool(OpOperand *opOperand)> ControlPropagationFn
Function type which is used to control propagation of linalg.pack/unpack ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LinalgOp > downscaleSizeOneWindowedConvolution(RewriterBase &rewriter, LinalgOp op)
Rewrite convolution/pooling/depthwise ops with size-1 window dimensions into lower-dimensional ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SmallVector< IntTy > extractFromIntegerArrayAttr(Attribute attr)
Extract integer values from the assumed ArrayAttr of IntegerAttr.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.