43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Debug.h"
47 #include <type_traits>
// Debug-logging helpers for this file, gated on LLVM's -debug machinery.
//  - DBGS()   streams the "[linalg-transforms]: " prefix to llvm::dbgs().
//  - DBGSNL() streams a bare newline to llvm::dbgs().
//  - LDBG(X)  logs X with the prefix and a trailing newline, but only in
//    debug builds / when the debug type is enabled (via LLVM_DEBUG).
// NOTE(review): the leading numerals on these lines ("53", "54", ...)
// look like line-number residue from an extraction tool, not real code —
// confirm against the upstream file before building.
53 #define DEBUG_TYPE "linalg-transforms"
54 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 #define DBGSNL() (llvm::dbgs() << "\n")
56 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
63 template <
typename PatternTy,
typename... Args>
66 using OpTy =
typename llvm::function_traits<
67 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
68 auto op = dyn_cast<OpTy>(operation);
73 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
78 explicit TrivialPatternRewriter(
MLIRContext *context)
81 TrivialPatternRewriter rewriter(operation->
getContext());
82 rewriter.setInsertionPoint(operation);
83 auto result = pattern.returningMatchAndRewrite(op, rewriter);
86 return cast<LinalgOp>(result->getOperation());
97 if (!isa<IntegerAttr>(ofr.get<
Attribute>()))
98 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
99 result.push_back(ofr);
104 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
106 if (params.size() != 1)
107 return transformOp.emitDefiniteFailure()
108 <<
"requires exactly one parameter associated";
109 result.push_back(params[0]);
113 auto payloadOps = state.getPayloadOps(transformValue);
114 if (!llvm::hasSingleElement(payloadOps)) {
116 transformOp.emitSilenceableError()
117 <<
"handle must be mapped to exactly one payload op";
119 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
126 transformOp.emitSilenceableError()
127 <<
"payload op must have exactly 1 index result";
147 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
149 for (
auto param : params) {
150 if (!isa<IntegerAttr>(param))
151 return transformOp.emitDefiniteFailure()
152 <<
"expected the parameter to be associated with an integer "
154 result.push_back(param);
159 for (
Operation *op : state.getPayloadOps(packedHandle)) {
162 transformOp.emitSilenceableError()
163 <<
"payload op must have exactly 1 index result";
178 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
183 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
189 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
192 options.rankReductionStrategy =
197 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
206 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
211 resultTypes.push_back(b.
getType<transform::AnyValueType>());
212 resultTypes.push_back(b.
getType<transform::AnyOpType>());
213 return build(b, result,
219 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
222 int64_t memorySpace) {
224 resultTypes.push_back(b.
getType<transform::AnyValueType>());
225 resultTypes.push_back(b.
getType<transform::AnyOpType>());
226 return build(b, result,
242 void notifyOperationInserted(
Operation *op,
244 ForwardingListener::notifyOperationInserted(op, previous);
246 if (previous.
isSet())
248 auto inserted = newOps.insert(op);
250 assert(inserted.second &&
"expected newly created op");
253 void notifyOperationErased(
Operation *op)
override {
254 ForwardingListener::notifyOperationErased(op);
268 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
269 NewOpsListener newOpsListener(previousListener);
273 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
276 }
else if (getMemcpyOp() ==
"memref.copy") {
279 }
else if (getMemcpyOp() ==
"linalg.copy") {
283 llvm_unreachable(
"invalid memcpy op");
285 if (getAllocOp() ==
"memref.alloc") {
288 }
else if (getAllocOp() ==
"memref.alloca") {
292 llvm_unreachable(
"invalid alloc op");
294 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
295 options.emitDealloc = getEmitDealloc();
299 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
301 for (
Operation *op : state.getPayloadOps(getTarget())) {
306 <<
"failed to bufferize operation";
307 diag.attachNote(op->
getLoc()) <<
"target payload op";
310 allocatedBuffers.push_back(buffer);
314 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
315 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
319 void transform::BufferizeToAllocationOp::getEffects(
321 if (getBufferizeDestinationOnly()) {
334 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
335 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
336 return emitOpError() <<
"unsupported memcpy op";
337 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
338 return emitOpError() <<
"unsupported alloc op";
351 #define DOWNSCALE(trans) \
353 FailureOr<LinalgOp> res = tryApply<trans>(target); \
354 if (succeeded(res)) { \
355 results.push_back(*res); \
356 return DiagnosedSilenceableFailure::success(); \
360 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
361 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
374 #undef DOWNSCALE_NORMAL
375 #undef DOWNSCALE_CALL
377 return emitDefaultSilenceableFailure(target);
391 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
392 if (!decomposableOp) {
394 "payload is not a decomposable op"));
395 return emitDefaultSilenceableFailure(target);
399 decomposableOp.decomposeOperation(rewriter);
400 if (
failed(maybeNewResults))
401 return emitDefaultSilenceableFailure(target);
403 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
404 for (
Value val : *maybeNewResults) {
405 Operation *definition = val.getDefiningOp();
416 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
423 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
427 options.allowReturnAllocsFromLoops =
true;
429 for (
Operation *target : state.getPayloadOps(getTarget())) {
433 <<
"failed to analyze op";
435 rewriter, target, state)))
437 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
448 template <
typename Range>
458 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
459 if (!tilingInterfaceOp)
460 return transformOp->
emitError(
"only TilingInterface ops are supported");
464 applyFn(tilingInterfaceOp);
470 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
471 for (
Operation *toReplace : opsToReplace) {
472 for (
OpResult res : toReplace->getResults())
473 if (
auto replacement = tiledResults->replacements.lookup(res))
475 if (toReplace->use_empty()) {
481 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
482 assert(tiledResults->loops.size() == numLoops &&
483 "Mismatched number of loops, tile and fuse transform should have "
485 for (
unsigned int i = 0; i < numLoops; ++i)
486 loopOps[i].push_back(tiledResults->loops[i]);
489 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
490 for (
unsigned int i = 0; i < numLoops; ++i)
491 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
501 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
503 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
509 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
513 rewriter, getOperation(), state.getPayloadOps(getTarget()),
514 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
515 [&](TilingInterface tilingInterfaceOp)
517 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
526 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
527 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
528 if (!std::is_permutation(sequence.begin(), sequence.end(),
529 permutation.begin(), permutation.end())) {
530 return emitOpError() <<
"expects interchange to be a permutation, found "
531 << getTileInterchange();
535 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
536 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
537 if (numExpectedLoops != getNumResults() - 1)
538 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
547 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
550 Value containingOp) {
553 result.
addTypes({resultType, resultType});
569 (domInfo.
dominates(containingOp, user))) {
570 dominatedUsers.insert(user);
573 if (dominatedUsers.empty())
577 auto forallOp = cast<scf::ForallOp>(containingOp);
583 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
588 newOuts.push_back(outputs[resultNumber]);
591 auto newforallOp = rewriter.
create<scf::ForallOp>(
592 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
593 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
595 newforallOp.getRegion().takeBody(forallOp.getRegion());
600 newforallOp.getBody()->addArgument(newOuts.back().getType(),
601 newOuts.back().getLoc());
602 auto bbArgs = newforallOp.getBody()->getArguments();
610 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
612 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
613 Operation *firstYieldOp = yieldingOps.front();
616 Value dst = newforallOp.getRegionIterArgs().back();
618 rewriter.
create<tensor::ParallelInsertSliceOp>(firstYieldOp->
getLoc(), src,
619 dst, offsets, sizes, strides);
623 newforallOp->getResult(result.index()));
626 newforallOp->getResults().back(),
629 return dominatedUsers.contains(user);
640 static std::tuple<SmallVector<Operation *>,
Operation *>
643 LLVM_DEBUG(
DBGS() <<
"Try to fuse a direct extract use\n");
644 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
645 if (!tileableProducer) {
647 <<
"producer is not a TileableInterface: " << *producerOp;
654 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
655 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
656 return sliceOp && containingOp->isProperAncestor(sliceOp);
660 if (it == tileableProducer->getUsers().end()) {
661 diag.attachNote(tileableProducer->getLoc())
662 <<
"could not find fusion opportunity for: " << *tileableProducer;
665 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
672 int64_t resultNumber =
673 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
674 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
680 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
683 if (
failed(tileAndFuseResult)) {
684 diag.attachNote(tileableProducer->getLoc())
685 <<
"failed to tile producer op: " << *tileableProducer;
690 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
691 LLVM_DEBUG(
DBGS() <<
"tiledProducer: " << *tiledOp <<
"\n");
696 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
697 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
698 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
699 if (
failed(maybeRankReduced)) {
701 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
702 << tileAndFuseResult->tiledValues[0]
703 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
706 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
710 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
711 resultNumber, offsets, sizes);
713 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
726 LLVM_DEBUG(
DBGS() <<
"Try to fuse an extract use through block argument\n");
728 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
729 if (!tileableProducer) {
731 <<
"producer is not a TileableInterface: " << *producerOp;
736 scf::ForallOp forallOp;
737 auto itProducerUses =
738 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
739 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
743 if (!forallOp || forallOp != containingOp) {
744 diag.attachNote(tileableProducer->getLoc())
745 <<
"could not find a use by the containing op: " << *tileableProducer;
760 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
761 return sliceOp && containingOp->isProperAncestor(sliceOp);
765 if (itBBArgUsers == bbArg.
getUsers().end()) {
767 <<
"could not find fusion opportunity for bbArg: " << bbArg;
770 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
778 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
779 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
784 rewriter, tileableProducer->getLoc(), tileableProducer,
785 destinationTensors))) {
786 diag.attachNote(tileableProducer->getLoc())
787 <<
"failed to get destination tensors for: " << *tileableProducer;
792 bvm.
map(destinationTensors[resultNumber], bbArg);
793 auto tileableProducerClone =
794 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
796 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
800 tileableProducerClone.generateResultTileValue(
801 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
802 sliceOpToTile.getMixedSizes());
803 if (
failed(tileAndFuseResult)) {
804 diag.attachNote(tileableProducer->getLoc())
805 <<
"failed to tile producer op: " << *tileableProducer;
810 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
811 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
812 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
813 assert(
succeeded(maybeRankReduced) &&
"unexpected shape");
814 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
819 destinationTensors.front());
822 return tileAndFuseResult->tiledOps;
828 LLVM_DEBUG(
DBGS() <<
"Try to fuse an use by cloning\n");
833 for (
OpOperand &use : result.getUses()) {
835 uses.push_back(&use);
840 if (containingOp == use.getOwner()) {
842 <<
"producer op use by containing op cannot be fused by cloning";
850 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
859 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
860 "Parallel insert slice is not a valid clone destination");
861 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
862 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
866 fusedOp = rewriter.
clone(*producerOp);
868 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
873 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
883 auto producerOps = state.getPayloadOps(getProducerOp());
884 auto containingOps = state.getPayloadOps(getContainingOp());
885 if (!llvm::hasSingleElement(containingOps)) {
887 <<
"requires exactly one containing_op handle (got "
888 << llvm::range_size(containingOps) <<
")";
890 Operation *containingOp = *containingOps.begin();
893 if (std::empty(producerOps)) {
895 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
904 for (
const auto &it :
enumerate(remainingProducers)) {
907 int64_t numUsesInContainingOp =
909 return containingOp->isAncestor(op);
914 if (numUsesInContainingOp > 0) {
915 if (numUsesInContainingOp == 1)
916 remainingProducers.erase(remainingProducers.begin() + it.index());
923 while (!remainingProducers.empty()) {
924 auto nextProducer = getNextProducer();
925 if (
failed(nextProducer)) {
927 <<
"could not find next producer to fuse into container";
928 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
936 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
943 auto [tiledOps, newContainingOp] =
945 if (!tiledOps.empty()) {
946 LLVM_DEBUG(
DBGS() <<
"\nFused a direct extract use\n" << *containingOp);
947 fusedOps.append(tiledOps);
948 if (newContainingOp) {
959 (void)replacementStatus;
961 "unable to update transform state mapping");
962 rewriter.
eraseOp(containingOp);
963 containingOp = newContainingOp;
970 rewriter,
diag, producerOp, containingOp);
971 if (!tiledContainingOpOperand.empty()) {
972 LLVM_DEBUG(
DBGS() <<
"\nFused an extract use through block argument\n"
974 fusedOps.append(tiledContainingOpOperand);
981 LLVM_DEBUG(
DBGS() <<
"\nFused an use by cloning\n" << *containingOp);
982 fusedOps.push_back(cloned);
988 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
989 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
993 void transform::FuseIntoContainingOp::getEffects(
1011 if (isa<GenericOp>(target)) {
1018 results.
push_back(generic->getOperation());
1021 return emitDefaultSilenceableFailure(target);
1034 if (!isa<GenericOp>(target)) {
1042 results.
push_back(named->getOperation());
1045 return emitDefaultSilenceableFailure(target);
1059 if (interchangeVector.empty()) {
1064 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1065 if (interchangeVector.size() != numLoops) {
1066 return emitSilenceableError()
1067 << getIteratorInterchangeAttrName() <<
" has length ("
1068 << interchangeVector.size()
1069 <<
") different from the number of loops in the target operation ("
1075 interchangeVector.end()));
1084 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1085 if (!std::is_permutation(sequence.begin(), sequence.end(),
1086 permutation.begin(), permutation.end())) {
1087 return emitOpError()
1088 <<
"expects iterator_interchange to be a permutation, found "
1089 << getIteratorInterchange();
1106 <<
"cannot lower to pad + expand + transpose";
1109 transformResults.
push_back(res->expandShapeOp);
1110 transformResults.
push_back(res->transposeOp);
1126 emitSilenceableError()
1127 <<
"cannot lower to transpose + collapse + extract";
1128 diag.attachNote(target->getLoc()) <<
"target payload op";
1131 transformResults.
push_back(res->emptyOp);
1132 transformResults.
push_back(res->transposeOp);
1133 transformResults.
push_back(res->collapseShapeOp);
1134 transformResults.
push_back(res->extractSliceOp);
1164 if (getOps().has_value())
1165 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1166 getOps()->getAsValueRange<StringAttr>().end());
1168 auto payloadOps = state.getPayloadOps(getTarget());
1169 if (!llvm::hasSingleElement(payloadOps)) {
1174 bool incorrectNumOperandTypes =
false;
1181 if (getInterface().has_value()) {
1182 auto iface = getInterface().value();
1183 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1186 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1187 !isa<TilingInterface>(op))
1189 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1190 !isa<LoopLikeOpInterface>(op))
1195 if (getOpAttrs().has_value()) {
1196 DictionaryAttr opAttrs = getOpAttrs().value();
1198 if (attr.getName() == getInterfaceAttrName() ||
1199 attr.getName() == getOpsAttrName())
1201 if (!op->
hasAttr(attr.getName()))
1203 if (op->
getAttr(attr.getName()) != attr.getValue())
1208 if (getFilterResultType().has_value()) {
1209 Type t = getFilterResultType().value();
1214 if (getFilterOperandTypes().has_value()) {
1215 mlir::ArrayAttr types = getFilterOperandTypes().value();
1218 if (types.size() == 1) {
1221 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1224 [&](
Type operandType) { return operandType == t; }))
1229 if (types.size() != operandTypes.size()) {
1230 incorrectNumOperandTypes =
true;
1234 for (
auto [attr, operandType] :
1235 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1236 auto typeattr = cast<mlir::TypeAttr>(attr);
1239 if (type != operandType)
1250 (*payloadOps.begin())->
walk(matchFun);
1251 if (incorrectNumOperandTypes)
1253 "type, then it must contain as much types as "
1254 "the number of operands in the target ops");
1255 results.
set(cast<OpResult>(getResult()), res);
1270 Type &targetType,
Type &lowSizeType,
1272 Type &splitPointType) {
1273 FunctionType funcType;
1278 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1279 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1280 "argument and one result";
1282 targetType = funcType.getInput(0);
1283 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1291 if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1292 if (target.hasDynamicShape()) {
1293 auto diag = emitSilenceableError()
1294 <<
"cannot compute parametric tile sizes for dynamically "
1295 "shaped payload op";
1296 diag.attachNote(target->getLoc()) <<
"payload op";
1301 target, getDimension(), getTargetSize(), getDivisor());
1303 return emitSilenceableError()
1304 <<
"failed to compute multi-size tiling sizes";
1307 Builder builder(target.getContext());
1308 results.
assign(llvm::map_range(
1310 spec->lowTileSize * spec->lowTripCount}),
1311 [&builder,
this](int64_t value) {
1313 cast<ParamType>(getLowSize().getType()).getType(), value);
1323 builder, target, getDimension(), targetSize, divisor);
1325 return emitSilenceableError() <<
"could not generate tile size computation";
1332 {spec->lowTileSize, spec->lowTripCount});
1333 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1334 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1335 assert(lowTileSize && highTileSize && splitPoint &&
1336 "tile sizes are not produced by operations");
1344 void transform::MultiTileSizesOp::getEffects(
1348 if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1355 if (getLowSize().getType() != getHighSize().getType() ||
1356 getLowSize().getType() != getSplitPoint().getType()) {
1357 return emitOpError() <<
"expects all results type to be the same";
1377 builder.
getContext(), GenericOp::getOperationName());
1378 build(builder, result,
1387 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1394 auto targetOps = state.getPayloadOps(getTarget());
1396 if (std::empty(targetOps)) {
1397 transformResults.
set(cast<OpResult>(getPackedOp()),
1402 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1403 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1404 return emitSilenceableError()
1405 <<
"requires target to map to exactly 1 LinalgOp (got "
1406 << llvm::range_size(targetOps) <<
")";
1409 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1410 return emitSilenceableError()
1411 <<
"requires number of packed sizes match the number of loops ("
1412 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1419 state, *
this, packedSizes, getMixedPackedSizes());
1426 transformResults.
set(cast<OpResult>(getPackedOp()),
1427 {maybeResult->packedLinalgOp.getOperation()});
1431 void transform::PackOp::getEffects(
1445 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1446 <<
" is not a valid permutation";
1449 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1450 for (
auto [s, nmo] :
1451 llvm::zip_equal(getMixedMatmulPackedSizes(),
1452 getMatmulPaddedSizesNextMultipleOf())) {
1455 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1456 return emitOpError() <<
"at most one of the packed_size and the "
1457 "padded_sizes_next_multiple_of can be nonzero "
1458 "for the matmul strategy";
1470 for (
Operation *op : state.getPayloadOps(getTarget())) {
1471 auto linalgOp = dyn_cast<LinalgOp>(op);
1482 getMixedMatmulPackedSizes(),
1484 getMatmulPaddedSizesNextMultipleOf(),
1485 getMatmulInnerDimsOrder());
1487 results.push_back(packResult->packedLinalgOp);
1490 results.push_back(linalgOp);
1492 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1498 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1502 void transform::PackGreedilyOp::getEffects(
1516 return emitOpError() << getInnerPermAttrName()
1517 <<
" is not a valid permutation";
1520 return emitOpError() << getOuterPermAttrName()
1521 <<
" is not a valid permutation";
1523 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1524 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1525 <<
" or " << getOuterPermAttrName()
1526 <<
" must be specified";
1532 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1542 template <
typename RelayoutOpTy>
1545 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1547 llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1548 "applies to only pack or unpack operations");
1549 if (!op || permutation.empty())
1551 size_t innerRank = op.getInnerDimsPos().size();
1552 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1556 if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1557 return permutation.size() == op.getSourceRank() &&
1560 return permutation.size() == op.getDestRank() &&
1568 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1569 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1571 if (std::empty(packOrUnpackOps)) {
1572 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1573 transformResults.
set(cast<OpResult>(getPackOp()), {});
1574 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1580 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1581 !llvm::hasSingleElement(linalgOps)) {
1582 return emitSilenceableError()
1583 <<
"requires target to map to exactly 1 "
1584 "packing op and 1 packed op ("
1585 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1586 << llvm::range_size(linalgOps) <<
")";
1590 auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1591 auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1592 if ((!packOp && !unPackOp)) {
1593 return emitSilenceableError() <<
"requires target to map to a "
1594 "tensor.pack or tensor.unpack";
1596 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1597 if (!linalgOpTarget)
1598 return emitSilenceableError() <<
"requires a LinalgOp target";
1602 if (packOp && packOp.getResult().hasOneUse())
1603 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1605 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1606 if (linalgOp != linalgOpTarget) {
1608 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1609 : StringLiteral{
"not produced by the LinalgOp target"};
1610 return emitSilenceableError() << errorMsg;
1616 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1617 OpOperand *packUse = linalgOp.getDpsInitOperand(
1618 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1620 if (!packOp || !packOp.getResult().hasOneUse())
1621 return emitSilenceableError() <<
"could not find matching pack op";
1625 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1627 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1628 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1629 ? StringLiteral{
"invalid outer_perm"}
1630 : StringLiteral{
"invalid inner_perm"};
1634 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1635 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1641 assert(packOp && linalgOp &&
"unexpected null op");
1645 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1647 assert(
succeeded(res) &&
"unexpected packTranspose failure");
1650 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1651 transformResults.
set(cast<OpResult>(getPackedOp()),
1652 {res->transposedLinalgOp});
1654 transformResults.
set(cast<OpResult>(getUnPackOp()),
1655 {res->transposedUnPackOp});
1657 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1672 StringRef copyBackOp) {
1681 (padToMultipleOf.empty() ? ArrayAttr()
1682 : b.getI64ArrayAttr(padToMultipleOf)),
1683 b.getI64ArrayAttr(packPaddings),
1684 b.getArrayAttr(transposePaddings),
1685 b.getStringAttr(copyBackOp));
1694 for (
Operation *target : state.getPayloadOps(getTarget())) {
1695 auto linalgTarget = dyn_cast<LinalgOp>(target);
1696 if (!linalgTarget) {
1697 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1698 diag.attachNote(target->
getLoc()) <<
"target op";
1704 for (int64_t packPadding :
1705 extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
1706 packPaddings.push_back(
static_cast<bool>(packPadding));
1710 for (
auto const &it :
1711 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1712 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1714 emitOpError(
"expects padding values to be typed attributes");
1719 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
1723 if (!parsedAttr || parsedAttr.getType() != elementType) {
1724 auto diag = this->emitOpError(
"expects a padding that parses to ")
1725 << elementType <<
", got " << std::get<0>(it);
1726 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1729 paddingValues.push_back(parsedAttr);
1733 if (attr.getType() != elementType) {
1734 auto diag = this->emitOpError(
"expects a padding value of type ")
1735 << elementType <<
", got " << attr;
1736 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1739 paddingValues.push_back(attr);
1744 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1745 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1746 cast<ArrayAttr>(transposeVector)));
1751 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1753 if (getPadToMultipleOf().has_value())
1755 extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
1756 options.padToMultipleOf = padToMultipleOf;
1757 options.paddingValues = paddingValues;
1758 options.packPaddings = packPaddings;
1759 if (getCopyBackOp() ==
1760 bufferization::MaterializeInDestinationOp::getOperationName()) {
1763 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1765 }
else if (getCopyBackOp() == kCopyOpNone) {
1768 llvm_unreachable(
"unsupported copy_back op");
1774 replacements, newPadOps))) {
1775 auto diag = emitSilenceableError() <<
"failed to pad op";
1776 diag.attachNote(target->
getLoc()) <<
"target op";
1785 rewriter.
replaceOp(linalgTarget, replacements);
1786 paddedOps.push_back(paddedOp);
1787 padOps.append(newPadOps.begin(), newPadOps.end());
1789 for (
Value v : replacements) {
1790 Operation *copyBackOp = v.getDefiningOp();
1791 if (!llvm::is_contained(copyBackOps, copyBackOp))
1792 copyBackOps.push_back(copyBackOp);
1797 results.
set(cast<OpResult>(getPadded()), paddedOps);
1798 results.
set(cast<OpResult>(getPad()), padOps);
1799 results.
set(cast<OpResult>(getCopy()), copyBackOps);
1805 extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
1806 if (any_of(packPaddings, [](int64_t packPadding) {
1807 return packPadding != 0 && packPadding != 1;
1809 return emitOpError()
1810 <<
"expects pack_paddings to contain booleans (0/1), found "
1811 << getPackPaddings();
1815 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1816 if (any_of(paddingDimensions,
1817 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
1818 return emitOpError() <<
"expects padding_dimensions to contain positive "
1820 << getPaddingDimensions();
1822 if (getPadToMultipleOf().has_value()) {
1823 if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
1824 return emitOpError() <<
"expects as many multiples as padding_dimensions";
1827 ArrayAttr transposes = getTransposePaddings();
1830 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
1831 if (!std::is_permutation(sequence.begin(), sequence.end(),
1833 return emitOpError()
1834 <<
"expects transpose_paddings to be a permutation, found "
1838 if (getCopyBackOp() !=
1839 bufferization::MaterializeInDestinationOp::getOperationName() &&
1840 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1841 getCopyBackOp() != kCopyOpNone)
1842 return emitOpError() <<
"invalid copy_back_op";
1854 auto targetOps = state.getPayloadOps(getTarget());
1855 auto loopOps = state.getPayloadOps(getLoop());
1856 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1858 <<
"requires exactly one target and one loop handle (got "
1859 << llvm::range_size(targetOps) <<
" and "
1860 << llvm::range_size(loopOps) <<
")";
1863 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1864 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1865 if (!padOp || !loopOp)
1874 if (result->clonedLoopIvs.empty()) {
1875 transformResults.
set(cast<OpResult>(getPackingLoop()),
1876 {result->hoistedPadOp.getOperation()});
1879 auto outerPackedLoop =
1881 transformResults.
set(cast<OpResult>(getPackingLoop()),
1882 {outerPackedLoop.getOperation()});
1888 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
1889 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
1891 return emitOpError() <<
"expects transpose to be a permutation, found "
1897 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
1907 tensor::PadOp target,
1910 tensor::PadOp hoistedPadOp;
1914 hoistedPadOp, transposeOps);
1925 return emitDefaultSilenceableFailure(target);
1930 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
1931 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
1933 return emitOpError() <<
"expects transpose to be a permutation, found "
1949 if (!getOperandsToPromote().empty())
1951 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
1952 if (getUseFullTilesByDefault())
1954 getUseFullTilesByDefault());
1956 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
1957 if (!getUseFullTileBuffers().empty())
1959 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
1960 if (getAlignment().has_value())
1961 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
1962 if (getMemorySpace().has_value())
1963 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
1965 if (getMapping().has_value()) {
1967 auto mapping = *getMapping();
1968 if (mapping.size() > 1)
1969 return emitDefaultDefiniteFailure(target);
1971 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
1973 if (addressSpace.getAddressSpace() ==
1974 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
1981 }
else if (addressSpace.getAddressSpace() ==
1982 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
1990 return emitDefaultDefiniteFailure(target);
1995 return emitDefaultDefiniteFailure(target);
2000 return emitDefaultDefiniteFailure(target);
2013 auto payload = state.getPayloadOps(getTarget());
2017 if (target->getNumOperands() > 0)
2020 target->getNumRegions() > 0)
2022 <<
"expected target that is isolated from above";
2026 Operation *pattern = &getBodyRegion().front().front();
2029 if (getOperation()->isAncestor(target))
2034 replacements.push_back(replacement);
2036 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2040 void transform::ReplaceOp::getEffects(
2048 if (!getBodyRegion().hasOneBlock())
2049 return emitOpError() <<
"expected one block";
2050 if (std::distance(getBodyRegion().front().begin(),
2051 getBodyRegion().front().end()) != 1)
2052 return emitOpError() <<
"expected one operation in block";
2053 Operation *replacement = &getBodyRegion().front().front();
2056 <<
"expected replacement without operands";
2060 <<
"expect op that is isolated from above";
2078 target.createFlatListOfOperandDims(b, loc);
2079 AffineMap map = target.getShapesToLoopsMap();
2096 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2097 if (
failed(maybeTilingResult))
2098 return emitDefaultDefiniteFailure(target);
2100 if (target->getNumResults())
2101 rewriter.
replaceOp(target, maybeTilingResult->replacements);
2105 results.
reserve(maybeTilingResult->tiledOps.size());
2106 for (
Operation *tiled : maybeTilingResult->tiledOps)
2120 for (
Operation *target : state.getPayloadOps(getTarget())) {
2121 auto tilingOp = dyn_cast<TilingInterface>(*target);
2124 emitSilenceableError()
2125 <<
"expected the payload to implement TilingInterface";
2126 diag.attachNote(target->getLoc()) <<
"payload op";
2132 if (
failed(generatedLoops))
2133 return emitDefaultDefiniteFailure(target);
2134 for (scf::ForOp &loop : *generatedLoops) {
2135 loops.push_back(loop.getOperation());
2139 results.
set(cast<OpResult>(getResult()), loops);
2148 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2156 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2157 [&rewriter](
auto op) {
2161 return emitDefaultSilenceableFailure(target);
2175 llvm::to_vector(state.getPayloadOps(getTarget()));
2177 splitPoints.reserve(payload.size());
2178 if (getDynamicSplitPoint()) {
2180 if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
2181 splitPoints = llvm::to_vector(llvm::map_range(
2182 state.getPayloadOps(getDynamicSplitPoint()), [&](
Operation *op) {
2185 diag = emitSilenceableError()
2186 <<
"expected dynamic split point handle to point to a "
2187 "single-result index-typed op";
2188 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2193 splitPoints = llvm::to_vector(
2194 llvm::map_range(state.getParams(getDynamicSplitPoint()),
2197 if (
diag.isSilenceableFailure())
2200 if (splitPoints.size() != payload.size()) {
2202 <<
"expected the dynamic split point handle to point to as "
2204 << splitPoints.size() <<
") as the target handle ("
2205 << payload.size() <<
")";
2208 splitPoints.resize(payload.size(),
2215 for (
const auto &pair : llvm::zip(payload, splitPoints)) {
2217 auto linalgOp = dyn_cast<LinalgOp>(target);
2219 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2220 diag.attachNote(target->
getLoc()) <<
"target op";
2224 if (getDimension() >= linalgOp.getNumLoops()) {
2225 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2226 <<
" does not exist in target op";
2227 diag.attachNote(target->
getLoc()) <<
"target op";
2232 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2233 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2234 getDimension(), std::get<1>(pair));
2237 if (!first.back() && !second.back()) {
2239 diag.attachNote(target->
getLoc()) <<
"target op";
2244 if (!second.back()) {
2245 noSecondPart = target;
2250 if (second.size() != first.size() && !second.empty()) {
2251 auto diag = emitSilenceableError()
2252 <<
"splitting does not produce the second part for a subset "
2254 diag.attachNote() <<
"expected splitting to produce the second part of all "
2255 "or none of the targets";
2257 <<
"first target with no second part";
2261 results.
set(cast<OpResult>(getFirst()), first);
2262 results.
set(cast<OpResult>(getSecond()), second);
2266 void SplitOp::getEffects(
2269 if (getDynamicSplitPoint())
2277 IntegerAttr staticSplitPoint;
2283 if (!dynamicPointParseResult.
has_value()) {
2284 int64_t staticSplitPointValue;
2298 if (dynamicPointParseResult.
has_value()) {
2299 Type splitPointType;
2312 SplitOp::getStaticSplitPointAttrName(result.
name).getValue(),
2314 result.
addTypes({targetType, targetType});
2319 printer <<
" " << getTarget() <<
" after ";
2320 int64_t staticSplitSize =
static_cast<int64_t
>(getStaticSplitPoint());
2321 if (staticSplitSize != ShapedType::kDynamic)
2322 printer << staticSplitSize;
2324 printer << getDynamicSplitPoint();
2327 {getStaticSplitPointAttrName()});
2328 printer <<
" : " << getTarget().getType();
2329 if (staticSplitSize == ShapedType::kDynamic)
2330 printer <<
", " << getDynamicSplitPoint().getType();
2334 if ((
static_cast<int64_t
>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
2335 (getDynamicSplitPoint() ==
nullptr)) {
2336 return emitOpError() <<
"expects either a dynamic or a static split "
2337 "point to be provided";
2346 void transform::SplitReductionOp::build(
2348 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2349 bool useScalingAlgorithm,
bool useAlloc) {
2355 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2357 if (innerParallel) {
2358 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2361 if (useScalingAlgorithm) {
2363 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2371 result.
addTypes({resultType, resultType, resultType, resultType});
2380 unsigned(getInsertSplitDimension()),
2381 bool(getInnerParallel())};
2385 (getUseScalingAlgorithm())
2389 return emitDefaultDefiniteFailure(target);
2391 results.
push_back(splitResult->initOrAlloc);
2393 results.
push_back(splitResult->splitLinalgOp);
2394 results.
push_back(splitResult->resultCombiningLinalgOp);
2402 void transform::TileReductionUsingForOp::build(
2413 build(builder, result,
2416 staticTileSizesAttr);
2425 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2429 return emitDefaultSilenceableFailure(target);
2431 results.
push_back(result->parallelTiledOp);
2433 results.
push_back(result->loops.front());
2441 void transform::TileReductionUsingForallOp::build(
2444 ArrayAttr mapping) {
2454 build(builder, result,
2457 staticNumThreadsAttr,
2458 staticTileSizesAttr,
2473 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2474 numThreads, tileSizes, getMapping());
2477 auto diag = emitSilenceableError() <<
"could not tile reduction";
2478 diag.attachNote(target.getLoc()) <<
"target operation";
2482 results.
push_back(result->parallelTiledOp);
2492 void transform::TileUsingForOp::build(
2497 return build(builder, result, loopTypes,
2501 interchange, scalableSizes);
2504 void transform::TileUsingForOp::build(
2508 build(builder, result, target,
2510 interchange, scalableSizes);
2513 void transform::TileUsingForOp::build(
2520 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2524 void transform::TileUsingForOp::build(
2536 unsigned numExpectedLoops =
2537 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2539 resultTypes.reserve(numExpectedLoops);
2540 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2541 "expected one loop type or as many as loops");
2542 if (loopTypes.size() == 1)
2543 resultTypes.append(numExpectedLoops, loopTypes[0]);
2545 llvm::append_range(resultTypes, loopTypes);
2547 if (scalableSizes.has_value())
2548 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2549 build(builder, result, target.
getType(),
2553 staticTileSizesAttr,
2555 expandedScalableSizes);
2560 return emitOpError(
"expected same number of sizes (")
2562 << getScalableSizes().size() <<
")";
2573 llvm::to_vector(state.getPayloadOps(getTarget()));
2579 if (isa<ParamType>(transformValue.getType())) {
2580 dynamicSizeProducers.push_back({});
2582 paramSizes.push_back(
2583 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
2584 return cast<IntegerAttr>(attr).getValue().getSExtValue();
2587 if (paramSizes.back().size() != targets.size()) {
2589 emitSilenceableError()
2590 <<
"expected as many parameter values ("
2591 << dynamicSizeProducers.back().size() <<
") as target ops ("
2592 << targets.size() <<
")";
2593 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
2599 paramSizes.push_back({});
2600 dynamicSizeProducers.push_back(
2601 llvm::to_vector(state.getPayloadOps(transformValue)));
2603 if (dynamicSizeProducers.back().size() != targets.size()) {
2605 emitSilenceableError()
2606 <<
"expected as many dynamic size-producing operations ("
2607 << dynamicSizeProducers.back().size() <<
") as target ops ("
2608 << targets.size() <<
")";
2609 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
2613 for (
Operation *op : dynamicSizeProducers.back()) {
2620 emitSilenceableError() <<
"expected sizes to be produced by ops "
2621 "with a single index-type result";
2622 diag.attachNote(op->
getLoc()) <<
"size producer op";
2623 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
2630 loops.resize(getLoops().size());
2631 auto scalableSizes = getScalableSizes();
2633 auto tilingInterface = dyn_cast<TilingInterface>(op);
2634 if (!tilingInterface) {
2636 emitSilenceableError()
2637 <<
"only ops implementing TilingInterface are supported";
2638 diag.attachNote(op->
getLoc()) <<
"target op";
2641 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
2643 emitSilenceableError()
2644 <<
"too many tiles provided, expected at most "
2645 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
2646 << tileSizes.size();
2647 diag.attachNote(op->
getLoc()) <<
"target op";
2652 if (tileSizes.empty()) {
2661 sizes.reserve(tileSizes.size());
2662 unsigned dynamicIdx = 0;
2665 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
2666 if (scalableSizes[ofrIdx]) {
2667 auto val = b.
create<arith::ConstantIndexOp>(
2668 getLoc(), attr.
cast<IntegerAttr>().getInt());
2674 sizes.push_back(attr);
2681 assert((dynamicSizes.empty() ^ params.empty()) &&
2682 "expected either dynamic sizes or parameters");
2683 if (!params.empty()) {
2686 sizes.push_back(dynamicSizes[index]->getResult(0));
2695 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
2696 if (
failed(maybeTilingResult))
2699 rewriter.
replaceOp(op, maybeTilingResult->replacements);
2701 tiled.append(maybeTilingResult->tiledOps);
2703 loops[en2.index()].push_back(en2.value());
2706 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
2708 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
2717 results.reserve(tileSizes.size());
2718 unsigned dynamicPos = 0;
2720 for (int64_t size : tileSizes) {
2721 if (size == ShapedType::kDynamic) {
2722 results.push_back(dynamic[dynamicPos++]);
2739 transform::TileUsingForOp::getInterchangeAttrName(result.
name),
2746 if (!interchangeVals.empty()) {
2747 p <<
" interchange = [";
2748 llvm::interleaveComma(interchangeVals, p,
2749 [&](int64_t integer) { p << integer; });
2759 FunctionType functionalType;
2760 llvm::SMLoc operandLoc;
2770 size_t numExpectedLoops =
2771 staticSizes.size() - llvm::count(staticSizes.
asArrayRef(), 0);
2772 if (functionalType.getNumResults() != numExpectedLoops + 1) {
2774 <<
"expected " << (numExpectedLoops + 1) <<
" result type(s)";
2776 if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
2778 <<
"expected " << dynamicSizes.size() + 1 <<
" operand type(s)";
2780 if (parser.
resolveOperand(target, functionalType.getInputs().front(),
2783 functionalType.getInputs().drop_front(),
2791 result.
addTypes(functionalType.getResults());
2796 p <<
' ' << getTarget();
2798 {}, getScalableSizesAttr(),
2802 (*this)->getAttrs(),
2803 {getInterchangeAttrName(getOperation()->getName()),
2804 getScalableSizesAttrName(getOperation()->getName()),
2805 getStaticSizesAttrName(getOperation()->getName())});
2810 void transform::TileUsingForOp::getEffects(
2823 void transform::TileUsingForallOp::build(
OpBuilder &builder,
2827 ArrayAttr mapping) {
2828 return build(builder, result,
2836 void transform::TileUsingForallOp::build(
OpBuilder &builder,
2840 ArrayAttr mapping) {
2850 build(builder, result,
2851 TypeRange{operationType, operationType},
2858 staticTileSizesAttr,
2862 void transform::TileUsingForallOp::build(
OpBuilder &builder,
2866 ArrayAttr mapping) {
2867 return build(builder, result, target,
2872 void transform::TileUsingForallOp::build(
OpBuilder &builder,
2876 ArrayAttr mapping) {
2887 build(builder, result,
2888 TypeRange{operationType, operationType},
2894 staticNumThreadsAttr,
2901 TransformOpInterface transformOp,
Operation *target,
2906 auto tileableOp = dyn_cast<TilingInterface>(target);
2909 transformOp.emitSilenceableError()
2910 <<
"only TilingInterface ops are supported";
2911 diag.attachNote(target->
getLoc()) <<
"target op";
2916 if (!mixedNumThreads.empty()) {
2921 rewriter, tileableOp, mixedTileSizes, mapping);
2924 if (
failed(maybeTilingResult))
2925 return transformOp.emitDefaultSilenceableFailure(tileableOp);
2926 rewriter.
replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
2928 tilingResult = *maybeTilingResult;
2936 auto transformOp = cast<TransformOpInterface>(getOperation());
2945 getPackedNumThreads()
2947 state, transformOp, mixedNumThreads, getPackedNumThreads())
2949 state, transformOp, mixedNumThreads, getMixedNumThreads());
2953 status = getPackedTileSizes()
2955 state, transformOp, mixedTileSizes, getPackedTileSizes())
2957 state, transformOp, mixedTileSizes, getMixedTileSizes());
2961 for (
Operation *target : state.getPayloadOps(getTarget())) {
2964 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
2965 getMapping(), tilingResult);
2966 if (!
diag.succeeded())
2968 tileOps.push_back(tilingResult.
tileOp);
2969 tiledOps.push_back(tilingResult.
tiledOp);
2972 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
2973 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
2978 void transform::TileUsingForallOp::getEffects(
2991 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3000 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3001 static_cast<int>(getPackedNumThreads() !=
Value());
3002 if (numThreadsSpec > 1)
3004 "num_threads and packed_num_threads are mutually exclusive");
3005 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3006 static_cast<int>(getPackedTileSizes() !=
Value());
3007 if (tileSizesSpec > 1)
3009 "tile_sizes and packed_tile_sizes are mutually exclusive");
3010 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3011 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3012 "must be specified");
3020 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3022 bool vectorizePadding,
bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3024 if (vectorizePadding) {
3026 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3030 if (vectorizeExtract) {
3032 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3036 if (flatten1DDepthwiseConv) {
3038 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3049 explicit VectorizationPattern(
MLIRContext *context,
3050 bool vectorizeExtract =
false,
3051 bool flattenConv =
false)
3053 vectorizeNDExtract(vectorizeExtract),
3054 flatten1DDepthwiseConv(flattenConv) {}
3057 LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
3060 return vectorize(rewriter, linalgOp, {},
3061 {}, vectorizeNDExtract,
3062 flatten1DDepthwiseConv);
3068 bool vectorizeNDExtract =
false;
3072 bool flatten1DDepthwiseConv =
false;
3077 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3082 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3083 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3089 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3090 getFlatten_1dDepthwiseConv());
3092 if (!getDisableTransferPermutationMapLoweringPatterns())
3095 if (!getDisableMultiReductionToContractPatterns())
3103 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3104 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3109 if (getVectorizePadding())
3116 return emitDefaultDefiniteFailure(target);
3134 llvm::SMLoc operandLoc;
3147 getVectorizeNdExtractAttrName(result.
name))))
3155 if (operandTypes.size() != dynamicSizes.size() + 1) {
3157 <<
"expected " << dynamicSizes.size() + 1 <<
" operand type(s)";
3168 result.
addAttribute(getStaticVectorSizesAttrName(result.
name), staticSizes);
3174 p <<
' ' << getTarget() <<
' ';
3175 if (!getMixedVectorSizes().empty()) {
3178 getStaticVectorSizesAttr(),
3179 {}, getScalableSizesAttr(),
3183 if (getVectorizeNdExtract())
3184 p << getVectorizeNdExtractAttrName() <<
' ';
3187 (*this)->getAttrs(),
3189 getScalableSizesAttrName(getOperation()->getName()),
3190 getStaticVectorSizesAttrName(getOperation()->getName())});
3192 p << getTarget().getType();
3193 if (!getVectorSizes().empty()) {
3195 llvm::interleaveComma(getVectorSizes(), p,
3204 auto targets = state.getPayloadOps(getTarget());
3205 if (std::empty(targets))
3212 vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
3216 if (params.size() != 1)
3218 vectorSizes.push_back(
3219 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
3223 auto szPayloads = state.getPayloadOps(sz.get<
Value>());
3224 if (!llvm::hasSingleElement(szPayloads)) {
3225 auto diag = this->emitOpError(
3226 "requires vector size handle that is mapped to 1 payload op");
3228 <<
"mapped to " << llvm::range_size(szPayloads) <<
" payload ops";
3232 Operation *szPayloadOp = *szPayloads.begin();
3235 auto diag = this->emitOpError(
3236 "requires vector size payload op with 1 index result");
3237 diag.attachNote(szPayloadOp->
getLoc()) <<
"vector size payload op";
3243 auto diag = this->emitOpError(
"requires constant vector size");
3244 diag.attachNote(szPayloadOp->
getLoc()) <<
"vector size payload op";
3248 vectorSizes.push_back(attr.getInt());
3253 if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3256 <<
"Unsupported Op, cannot vectorize";
3261 getVectorizeNdExtract().has_value()
3262 ? getVectorizeNdExtract().value()
3265 <<
"Attempted to vectorize, but failed";
3272 void transform::VectorizeOp::getEffects(
3281 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3285 if (getStaticVectorSizes().size() != getScalableSizes().size())
3286 return emitOpError(
"expected same number of vector sizes (")
3287 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3288 << getScalableSizes().size() <<
")";
3297 transform::HoistRedundantVectorTransfersOp::applyToOne(
3318 auto maybeTransformed =
3321 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3324 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3327 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3330 .Case([&](linalg::Conv2DNchwFchwOp op) {
3336 if (
failed(maybeTransformed))
3337 return emitDefaultSilenceableFailure(target);
3339 results.
push_back(maybeTransformed->first);
3341 results.
push_back(maybeTransformed->second);
3356 <<
"only elementwise flattening is supported";
3359 if (target.getNumLoops() <= 1) {
3366 std::iota(reassociation.begin(), reassociation.end(), 0);
3367 auto maybeFlattened =
3369 if (
failed(maybeFlattened))
3371 <<
"attempted to flatten, but failed";
3372 results.
push_back(maybeFlattened->collapsedOp);
3373 rewriter.
replaceOp(target, maybeFlattened->results);
3386 auto maybeTransformed =
3388 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3391 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3397 if (
failed(maybeTransformed))
3398 return emitDefaultSilenceableFailure(target);
3407 template <
typename OpTy>
3411 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3412 tensor::ParallelInsertSliceOp>() &&
3415 if (
auto copySource =
3416 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3423 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3425 target->template getParentOfType<scf::InParallelOp>());
3428 Value extracted = rewriter.
create<tensor::ExtractSliceOp>(
3429 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3430 target.getMixedSizes(), target.getMixedStrides());
3431 Value copied = rewriter
3432 .
create<linalg::CopyOp>(target.getLoc(),
3433 target.getSource(), extracted)
3438 target, copied, target.getDest(), target.getMixedOffsets(),
3439 target.getMixedSizes(), target.getMixedStrides());
3441 results.
push_back(copied.getDefiningOp());
3451 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3452 return doit(rewriter, target, results, state);
3453 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3454 return doit(rewriter, target, results, state);
3457 emitSilenceableError()
3458 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3459 diag.attachNote(targetOp->
getLoc()) <<
"target op";
3472 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3474 emitSilenceableError()
3475 <<
"only linalg.copy and tensor.pad target ops are supported";
3476 diag.attachNote(target->
getLoc()) <<
"target op";
3479 assert(target->
getNumResults() == 1 &&
"expected single result");
3481 if (!resultShapedType.hasStaticShape()) {
3483 emitSilenceableError()
3484 <<
"only statically sized ops of rank <= 3 are supported";
3485 diag.attachNote(target->
getLoc()) <<
"target op";
3490 int64_t desiredBitAlignment = getDesiredBitAlignment();
3491 int64_t eltBitwidth =
3492 resultShapedType.getElementType().getIntOrFloatBitWidth();
3493 if (desiredBitAlignment % eltBitwidth != 0) {
3494 desiredBitAlignment = eltBitwidth;
3499 getTotalNumThreads(),
3500 desiredBitAlignment,
3501 resultShapedType.getShape(),
3504 resultShapedType.getElementType().getIntOrFloatBitWidth());
3505 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3507 emitSilenceableError()
3508 <<
"too few threads to map copy op to threads on the most minor "
3509 "dimension, given alignment and vector size constraints, try "
3510 "smaller tile size of mapping to more threads";
3511 diag.attachNote(target->
getLoc()) <<
"target op";
3527 if (!
diag.succeeded())
3535 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3537 #define GET_OP_CLASSES
3538 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static FailureOr< ForallTilingResult > tileToForallOpImpl(RewriterBase &b, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayRef< OpFoldResult >> nominalTileSizes, std::optional< ArrayAttr > mapping, bool omitTileOffsetBoundsCheck)
Rewrite a TilingInterface op to a tiled scf.forall.
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class provides support for representing a failure result, or a valid value of type T.
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an attribute of type LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printer.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg.matmul.
FailureOr< ForallTilingResult > tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, std::optional< ArrayAttr > mapping)
Same as tileToForallOp, but calculate the number of threads required using the given tileSizes.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite pack as empty + transpose + reshape + extract_slice.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going outside of the scope.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< ForallTilingResult > tileToForallOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayAttr > mapping)
void hoistRedundantVectorTransfers(Operation *root)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociative reshape ops.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensure data integrity.
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< GenericOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel dimensions and k is a reduction dimension.
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoint.
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetSize.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the contract.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), ArrayRef< bool > scalables={}, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true, will be replaced by dynamicValues.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
static LogicalResult failure(bool isFailure=true)
If isFailure is true a failure result is generated, otherwise a 'success' result is generated.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise convolution ops.
Rewrite a TilingInterface op to a tiled scf.forall, applying tiling by numThreads.
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.