44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
// Debug-logging helpers: all LLVM_DEBUG output from this file is tagged with
// the "linalg-transforms" debug type (enable with -debug-only=linalg-transforms).
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
// Fragment of a helper that applies a single rewrite pattern `PatternTy`
// directly to `operation` via the pattern's returningMatchAndRewrite hook.
// OpTy is deduced as the first argument type of returningMatchAndRewrite,
// so the pattern can only fire when `operation` is of that op type.
// NOTE(review): interior lines are elided in this excerpt — confirm the
// failure paths (e.g. when the dyn_cast below fails) against the full file.
64 template <
typename PatternTy,
typename... Args>
67 using OpTy =
typename llvm::function_traits<
68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69 auto op = dyn_cast<OpTy>(operation);
74 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
// TrivialPatternRewriter: a minimal rewriter used purely to drive the single
// pattern above (presumably no extra listening/tracking — confirm in full file).
79 explicit TrivialPatternRewriter(
MLIRContext *context)
82 TrivialPatternRewriter rewriter(operation->
getContext());
83 rewriter.setInsertionPoint(operation);
84 auto result = pattern.returningMatchAndRewrite(op, rewriter);
87 return cast<LinalgOp>(result->getOperation());
// Fragment: normalizes one transform operand `ofr` into `result`.
// Three cases are visible below: a plain attribute must be an IntegerAttr and
// is forwarded as-is; a transform param value must have exactly one associated
// parameter; a payload handle must map to exactly one payload op, and the
// diagnostics indicate that op must have exactly 1 index result.
97 if (
auto attr = dyn_cast<Attribute>(ofr)) {
98 if (!isa<IntegerAttr>(attr))
99 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
100 result.push_back(ofr);
104 Value transformValue = cast<Value>(ofr);
105 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
107 if (params.size() != 1)
108 return transformOp.emitDefiniteFailure()
109 <<
"requires exactly one parameter associated";
110 result.push_back(params[0]);
// Handle case: the handle must be mapped to a single payload op.
114 auto payloadOps = state.getPayloadOps(transformValue);
115 if (!llvm::hasSingleElement(payloadOps)) {
117 transformOp.emitSilenceableError()
118 <<
"handle must be mapped to exactly one payload op";
120 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
127 transformOp.emitSilenceableError()
128 <<
"payload op must have exactly 1 index result";
// Fragment: unpacks a *packed* handle into `result`. If the handle is a
// transform param, every associated parameter must be an IntegerAttr;
// otherwise each mapped payload op must produce exactly one index-typed
// result, whose value is collected.
148 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
150 for (
auto param : params) {
151 if (!isa<IntegerAttr>(param))
152 return transformOp.emitDefiniteFailure()
153 <<
"expected the parameter to be associated with an integer "
155 result.push_back(param);
// Payload-handle case: require a single index result per mapped op.
160 for (
Operation *op : state.getPayloadOps(packedHandle)) {
161 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
163 transformOp.emitSilenceableError()
164 <<
"payload op must have exactly 1 index result";
165 diag.attachNote(op->getLoc())
166 <<
"has " << op->getNumResults() <<
" results";
169 result.push_back(op->getResult(0));
// Fragment: reifies a mixed param-or-handle operand into concrete int64_t
// values (`reified`). A bare attribute contributes its integer directly; a
// param value must carry exactly one IntegerAttr; a payload handle must map
// to a single payload op which (per the diagnostics) must be a constant-like
// op with one index result whose attribute value is extracted.
183 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
184 reified.push_back(cast<IntegerAttr>(attr).getInt());
186 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
188 if (params.size() != 1)
189 return transformOp.emitSilenceableError() <<
"expected a single param";
191 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
// Handle case: must be a value handle mapped to exactly one payload op.
195 Value handle = cast<Value>(paramOrHandle);
196 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
197 return transformOp.emitSilenceableError() <<
"unexpected value handle";
198 auto payload = state.getPayloadOps(handle);
199 if (!llvm::hasSingleElement(payload))
200 return transformOp.emitSilenceableError()
201 <<
"requires param or handle that is mapped to 1 payload op";
203 Operation *paramOrHandlePayloadOp = *payload.begin();
206 return transformOp.emitSilenceableError()
207 <<
"requires param or handle to be result of op with 1 index "
213 return transformOp.emitSilenceableError()
214 <<
"requires param or handle to be the result of a constant like "
217 reified.push_back(attr.getInt());
// Signatures of the Apply*PatternsOp::populatePatterns hooks. Bodies are
// elided in this excerpt; presumably each registers the corresponding Linalg
// pattern set into the supplied RewritePatternSet — confirm against full file.
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
// The slices variant configures a rank-reduction strategy on its options
// (the assigned value is elided here).
250 options.rankReductionStrategy =
255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
// Two build() overloads for BufferizeToAllocationOp. Both produce the same
// two result types: an AnyValueType (the allocated buffer value) followed by
// an AnyOpType (the handle to newly created ops), then delegate to the
// generated builder. The second overload additionally takes an integer
// memory space.
284 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
289 resultTypes.push_back(b.
getType<transform::AnyValueType>());
290 resultTypes.push_back(b.
getType<transform::AnyOpType>());
291 return build(b, result,
297 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
300 int64_t memorySpace) {
302 resultTypes.push_back(b.
getType<transform::AnyValueType>());
303 resultTypes.push_back(b.
getType<transform::AnyOpType>());
304 return build(b, result,
// Fragment of a listener that records operations created during the rewrite.
// notifyOperationInserted forwards to the wrapped listener first; the
// `previous.isSet()` guard presumably skips ops that were merely *moved*
// rather than newly created (confirm against full file). The assert enforces
// that each op is recorded at most once.
320 void notifyOperationInserted(
Operation *op,
322 ForwardingListener::notifyOperationInserted(op, previous);
324 if (previous.
isSet())
326 auto inserted = newOps.insert(op);
328 assert(inserted.second &&
"expected newly created op");
// Erased ops are forwarded (and presumably removed from the tracked set in
// the elided lines).
331 void notifyOperationErased(
Operation *op)
override {
332 ForwardingListener::notifyOperationErased(op);
// Fragment of BufferizeToAllocationOp::apply. The scope_exit restores the
// rewriter's previous listener no matter how this function exits; the
// NewOpsListener tracks ops created during bufferization so they can be
// returned through the getNewOps() result.
346 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
347 NewOpsListener newOpsListener(previousListener);
// Translate the string attributes into bufferization options. Invalid values
// are unreachable here because verify() rejects them (see the verifier below
// in this file).
351 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
354 }
else if (getMemcpyOp() ==
"memref.copy") {
357 }
else if (getMemcpyOp() ==
"linalg.copy") {
361 llvm_unreachable(
"invalid memcpy op");
363 if (getAllocOp() ==
"memref.alloc") {
366 }
else if (getAllocOp() ==
"memref.alloca") {
370 llvm_unreachable(
"invalid alloc op");
372 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373 options.emitDealloc = getEmitDealloc();
// Optional memory space: empty Attribute when unset.
377 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
// Bufferize each payload op; a failure attaches a note pointing at the
// offending payload op.
379 for (
Operation *op : state.getPayloadOps(getTarget())) {
384 <<
"failed to bufferize operation";
385 diag.attachNote(op->
getLoc()) <<
"target payload op";
388 allocatedBuffers.push_back(buffer);
// Map the two results: allocated buffer values and newly created ops.
392 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
393 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
// getEffects fragment (body largely elided; branches on the
// bufferize_destination_only flag) followed by the verifier, which restricts
// memcpy_op / alloc_op to the exact string values apply() dispatches on above.
397 void transform::BufferizeToAllocationOp::getEffects(
399 if (getBufferizeDestinationOnly()) {
411 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
412 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
413 return emitOpError() <<
"unsupported memcpy op";
414 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
415 return emitOpError() <<
"unsupported alloc op";
// Helper macros for the conv-downscaling transform: DOWNSCALE tries one
// pattern via tryApply and, on success, records the resulting LinalgOp and
// returns success; DOWNSCALE_NORMAL instantiates
// DownscaleSizeOneWindowed2DConvolution<a, b> for a conv-op pair.
// (No comments are placed inside the macro body below: backslash
// line-continuations make that unsafe.)
428 #define DOWNSCALE(trans) \
430 FailureOr<LinalgOp> res = tryApply<trans>(target); \
431 if (succeeded(res)) { \
432 results.push_back(*res); \
433 return DiagnosedSilenceableFailure::success(); \
437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
451 #undef DOWNSCALE_NORMAL
452 #undef DOWNSCALE_CALL
// Reached when no downscale pattern applied to the target.
454 return emitDefaultSilenceableFailure(target);
// Fragment: decomposes a payload op through AggregatedOpInterface.
// Non-decomposable targets and failed decompositions produce the op's
// default silenceable failure; on success the original op is replaced by
// the decomposed results and each result's defining op is inspected
// (continuation elided).
468 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
469 if (!decomposableOp) {
471 "payload is not a decomposable op"));
472 return emitDefaultSilenceableFailure(target);
475 FailureOr<SmallVector<Value>> maybeNewResults =
476 decomposableOp.decomposeOperation(rewriter);
477 if (failed(maybeNewResults))
478 return emitDefaultSilenceableFailure(target);
480 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
481 for (
Value val : *maybeNewResults) {
482 Operation *definition = val.getDefiningOp();
// Fragment of EliminateLinalgOpAnchoredEmptyTensorsOp: the apply() below
// enables allowReturnAllocsFromLoops on its (one-shot analysis) options,
// then, per payload op, runs an analysis step and an elimination step —
// either failure produces the diagnostics quoted below.
493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
504 options.allowReturnAllocsFromLoops =
true;
506 for (
Operation *target : state.getPayloadOps(getTarget())) {
510 <<
"failed to analyze op";
512 rewriter, target, state)))
514 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
// Fragment of a template helper that applies a tile-and-fuse callback to a
// range of payload ops and maps the results onto the transform op's results:
// result 0 receives the tiled linalg ops, results 1..numLoops receive the
// generated loops, one list per loop depth.
525 template <
typename Range>
529 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
// Each payload op must implement TilingInterface; otherwise this is a
// definite error on the transform op.
535 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
536 if (!tilingInterfaceOp)
537 return transformOp->
emitError(
"only TilingInterface ops are supported");
540 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
541 applyFn(tilingInterfaceOp);
542 if (failed(tiledResults))
// Replace uses of the original (and fused producer) ops with the
// replacement values computed by tiling, then erase dead originals.
547 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548 for (
Operation *toReplace : opsToReplace) {
549 for (
OpResult res : toReplace->getResults())
550 if (
auto replacement = tiledResults->replacements.lookup(res))
552 if (toReplace->use_empty()) {
558 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
559 assert(tiledResults->loops.size() == numLoops &&
560 "Mismatched number of loops, tile and fuse transform should have "
562 for (
unsigned int i = 0; i < numLoops; ++i)
563 loopOps[i].push_back(tiledResults->loops[i]);
// Map collected ops onto the transform op's results.
566 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
567 for (
unsigned int i = 0; i < numLoops; ++i)
568 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
// Fragment of FuseOp::apply: reads tile sizes / interchange from integer
// array attributes, configures tiling options, optionally adds cleanup
// (canonicalization) patterns, and tiles+fuses each payload op via
// tileConsumerAndFuseProducersUsingSCF. The number of expected loop results
// is the number of non-zero tile sizes.
578 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
580 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
586 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
590 if (getApplyCleanup()) {
593 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
600 rewriter, getOperation(), state.getPayloadOps(getTarget()),
601 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602 [&](TilingInterface tilingInterfaceOp)
603 -> FailureOr<scf::SCFTileAndFuseResult> {
604 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
// Verifier fragment: the interchange attribute must be a permutation of
// [0, N), and the op must declare exactly one result per non-zero tile size
// plus the fused-op result.
613 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615 if (!std::is_permutation(sequence.begin(), sequence.end(),
616 permutation.begin(), permutation.end())) {
617 return emitOpError() <<
"expects interchange to be a permutation, found "
618 << getTileInterchange();
622 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624 if (numExpectedLoops != getNumResults() - 1)
625 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
// FuseIntoContainingOp::build fragment: both results share the same type.
634 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
637 Value containingOp) {
640 result.
addTypes({resultType, resultType});
// Fragment of a helper that rebuilds an scf.forall with an additional
// shared output so that users of the fused producer that are dominated by
// the containing op can be redirected to the new forall result. NOTE(review):
// large parts of this helper are elided; the comments below only describe
// what the visible lines establish.
656 (domInfo.
dominates(containingOp, user))) {
657 dominatedUsers.insert(user);
660 if (dominatedUsers.empty())
664 auto forallOp = cast<scf::ForallOp>(containingOp);
670 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
675 newOuts.push_back(outputs[resultNumber]);
// Create the replacement forall with the extended outputs and steal the old
// body, then add a block argument for the new output.
678 auto newforallOp = rewriter.
create<scf::ForallOp>(
679 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
682 newforallOp.getRegion().takeBody(forallOp.getRegion());
687 newforallOp.getBody()->addArgument(newOuts.back().getType(),
688 newOuts.back().getLoc());
689 auto bbArgs = newforallOp.getBody()->getArguments();
// Insert a parallel_insert_slice into the terminator, writing `src` into the
// new region iter_arg.
697 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
699 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
700 Operation *firstYieldOp = yieldingOps.front();
703 Value dst = newforallOp.getRegionIterArgs().back();
705 rewriter.
create<tensor::ParallelInsertSliceOp>(firstYieldOp->
getLoc(), src,
706 dst, offsets, sizes, strides);
710 newforallOp->getResult(result.index()));
// Only the previously collected dominated users are rerouted to the new
// result.
713 newforallOp->getResults().back(),
716 return dominatedUsers.contains(user);
// Fragment of a worklist-based equivalence check: walks from `dst` through
// loop iter_args / tied loop inits to decide whether `src` refers to the same
// underlying destination value.
730 destWorklist.push_back(dst);
732 while (!destWorklist.empty()) {
733 Value currentDst = destWorklist.pop_back_val();
737 if (src == currentDst)
742 auto bbArg = dyn_cast<BlockArgument>(currentDst);
746 Block *parentBlock = bbArg.getOwner();
747 assert(parentBlock &&
"unlinked block argument");
750 assert(parentOp &&
"expected block argument with parent operation");
753 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
757 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
759 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
760 Value loopBlockArgument =
762 destWorklist.push_back(loopBlockArgument);
// Fragment of the "fuse a direct tensor.extract_slice use" strategy: finds an
// extract_slice of the producer inside the containing op, tiles the producer
// to that slice via TilingInterface::generateResultTileValue, and replaces
// the slice with the (rank-reduced if necessary) tiled value. Returns the
// tiled ops plus a possibly-new containing op.
775 static std::tuple<SmallVector<Operation *>,
Operation *>
778 LLVM_DEBUG(
DBGS() <<
"Try to fuse a direct extract use\n");
779 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
780 if (!tileableProducer) {
782 <<
"producer is not a TileableInterface: " << *producerOp;
// Look for any extract_slice user of the producer nested inside the
// containing op.
789 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
790 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
791 return sliceOp && containingOp->isProperAncestor(sliceOp);
795 if (it == tileableProducer->getUsers().end()) {
796 diag.attachNote(tileableProducer->getLoc())
797 <<
"could not find fusion opportunity for: " << *tileableProducer;
800 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
// When the slice lives inside a loop, retarget the cloned producer's DPS
// inits to the loop's iter_args when they refer to an equivalent destination
// (so destination-passing style is preserved across the loop boundary).
813 if (LoopLikeOpInterface containerLoop =
814 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
821 cast<DestinationStyleOpInterface>(
clone).getDpsInitsMutable()) {
822 Value producerOperand =
825 containerLoop.getRegionIterArgs()) {
826 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
827 Value consumerOperand =
828 containerLoop->getOperand(bbArg->getOperandNumber());
830 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
831 initOperandPtr.set(containerIterArg);
837 tileableProducer = dyn_cast<TilingInterface>(
clone);
// Tile the producer result feeding the slice.
841 int64_t resultNumber =
842 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
843 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
848 FailureOr<TilingResult> tileAndFuseResult =
849 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
852 if (failed(tileAndFuseResult)) {
853 diag.attachNote(tileableProducer->getLoc())
854 <<
"failed to tile producer op: " << *tileableProducer;
859 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
860 LLVM_DEBUG(
DBGS() <<
"tiledProducer: " << *tiledOp <<
"\n");
// The tiled value may need rank reduction to match the slice's result type.
865 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
866 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
867 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
868 if (failed(maybeRankReduced)) {
870 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
871 << tileAndFuseResult->tiledValues[0]
872 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
875 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
879 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
880 resultNumber, offsets, sizes);
// When the container is a loop the (cloned) producer is no longer needed.
883 if (dyn_cast<LoopLikeOpInterface>(containingOp))
884 rewriter.
eraseOp(tileableProducer);
886 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
// Fragment of the "fuse through a block argument" strategy: handles the case
// where the producer reaches the containing scf.forall as a shared output and
// is sliced via the corresponding block argument inside the body.
899 LLVM_DEBUG(
DBGS() <<
"Try to fuse an extract use through block argument\n");
901 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
902 if (!tileableProducer) {
904 <<
"producer is not a TileableInterface: " << *producerOp;
// Find the use of the producer whose owner is the containing scf.forall.
909 scf::ForallOp forallOp;
910 auto itProducerUses =
911 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
912 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
916 if (!forallOp || forallOp != containingOp) {
917 diag.attachNote(tileableProducer->getLoc())
918 <<
"could not find a use by the containing op: " << *tileableProducer;
// Find an extract_slice of the tied block argument inside the forall body.
933 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
934 return sliceOp && containingOp->isProperAncestor(sliceOp);
938 if (itBBArgUsers == bbArg.
getUsers().end()) {
940 <<
"could not find fusion opportunity for bbArg: " << bbArg;
943 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
951 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
952 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
// Clone the producer with its destination tensor remapped to the block
// argument so the clone tiles against the forall's shared output.
957 rewriter, tileableProducer->getLoc(), tileableProducer,
958 destinationTensors))) {
959 diag.attachNote(tileableProducer->getLoc())
960 <<
"failed to get destination tensors for: " << *tileableProducer;
965 bvm.
map(destinationTensors[resultNumber], bbArg);
966 auto tileableProducerClone =
967 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
// The clone is temporary: erase it on all exit paths.
969 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
972 FailureOr<TilingResult> tileAndFuseResult =
973 tileableProducerClone.generateResultTileValue(
974 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
975 sliceOpToTile.getMixedSizes());
976 if (failed(tileAndFuseResult)) {
977 diag.attachNote(tileableProducer->getLoc())
978 <<
"failed to tile producer op: " << *tileableProducer;
// Rank-reduce if needed and replace the slice with the tiled value; here a
// shape mismatch is considered a programming error (assert, not diagnostic).
983 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
984 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
985 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
986 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
987 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
992 destinationTensors.front());
995 return tileAndFuseResult->tiledOps;
// Fragment of the fallback fusion strategy: clone the producer next to (one
// of) its uses inside the containing op and redirect that use to the clone.
// Uses owned by the containing op itself cannot be fused this way, and a
// parallel_insert_slice is explicitly rejected as a clone destination.
1001 LLVM_DEBUG(
DBGS() <<
"Try to fuse an use by cloning\n");
1006 for (
OpOperand &use : result.getUses()) {
1008 uses.push_back(&use);
1013 if (containingOp == use.getOwner()) {
1015 <<
"producer op use by containing op cannot be fused by cloning";
1023 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1032 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1033 "Parallel insert slice is not a valid clone destination");
1034 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1035 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
// Clone and rewire the chosen use inside a modifyOpInPlace-style callback.
1039 fusedOp = rewriter.
clone(*producerOp);
1041 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
// Fragment of FuseIntoContainingOp::apply: repeatedly picks the next producer
// that has uses inside the (single) containing op and tries the three fusion
// strategies in order — direct extract use, extract through block argument,
// then cloning — accumulating fused ops and tracking containing-op
// replacement.
1046 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1056 auto producerOps = state.getPayloadOps(getProducerOp());
1057 auto containingOps = state.getPayloadOps(getContainingOp());
1058 if (!llvm::hasSingleElement(containingOps)) {
1060 <<
"requires exactly one containing_op handle (got "
1061 << llvm::range_size(containingOps) <<
")";
1063 Operation *containingOp = *containingOps.begin();
// Trivial case: nothing to fuse.
1066 if (std::empty(producerOps)) {
1068 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
// Pick the next producer with at least one use inside the containing op;
// it is removed from the worklist once its last such use is about to be
// fused.
1075 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1076 for (
const auto &it :
enumerate(remainingProducers)) {
1079 int64_t numUsesInContainingOp =
1081 return containingOp->isAncestor(op);
1086 if (numUsesInContainingOp > 0) {
1087 if (numUsesInContainingOp == 1)
1088 remainingProducers.erase(remainingProducers.begin() + it.index());
1095 while (!remainingProducers.empty()) {
1096 auto nextProducer = getNextProducer();
1097 if (failed(nextProducer)) {
1099 <<
"could not find next producer to fuse into container";
1100 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1108 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
// Strategy 1: direct extract_slice use. A new containing op may be created
// (forall signature extension), in which case the transform state mapping is
// updated and the old containing op erased.
1115 auto [tiledOps, newContainingOp] =
1117 if (!tiledOps.empty()) {
1118 LLVM_DEBUG(
DBGS() <<
"\nFused a direct extract use\n" << *containingOp);
1119 fusedOps.append(tiledOps);
1120 if (newContainingOp) {
1128 LogicalResult replacementStatus =
1131 (void)replacementStatus;
1132 assert(succeeded(replacementStatus) &&
1133 "unable to update transform state mapping");
1134 rewriter.
eraseOp(containingOp);
1135 containingOp = newContainingOp;
// Strategy 2: extract through the containing op's block argument.
1142 rewriter,
diag, producerOp, containingOp);
1143 if (!tiledContainingOpOperand.empty()) {
1144 LLVM_DEBUG(
DBGS() <<
"\nFused an extract use through block argument\n"
1146 fusedOps.append(tiledContainingOpOperand);
// Strategy 3: fuse by cloning.
1153 LLVM_DEBUG(
DBGS() <<
"\nFused an use by cloning\n" << *containingOp);
1154 fusedOps.push_back(cloned);
// Map both results: the fused ops and the (possibly replaced) containing op.
1160 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1161 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
// Fragments of three applyToOne-style transforms:
// - generalize: named GenericOp targets pass through; otherwise generalize
//   and return the resulting generic op.
// - specialize: the inverse — only GenericOp targets are specialized back to
//   a named op.
// - interchange: validated permutation of the iterator dimensions.
1165 void transform::FuseIntoContainingOp::getEffects(
1183 if (isa<GenericOp>(target)) {
1189 if (succeeded(
generic)) {
1190 results.
push_back(generic->getOperation());
1193 return emitDefaultSilenceableFailure(target);
1206 if (!isa<GenericOp>(target)) {
1211 FailureOr<LinalgOp> named =
1213 if (succeeded(named)) {
1214 results.
push_back(named->getOperation());
1217 return emitDefaultSilenceableFailure(target);
// Interchange: empty vector is a no-op; otherwise the vector length must
// match the target's loop count.
1231 if (interchangeVector.empty()) {
1236 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1237 if (interchangeVector.size() != numLoops) {
1238 return emitSilenceableError()
1239 << getIteratorInterchangeAttrName() <<
" has length ("
1240 << interchangeVector.size()
1241 <<
") different from the number of loops in the target operation ("
// Verifier: iterator_interchange must be a permutation of [0, N).
1254 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1255 if (!std::is_permutation(sequence.begin(), sequence.end(),
1256 permutation.begin(), permutation.end())) {
1257 return emitOpError()
1258 <<
"expects iterator_interchange to be a permutation, found "
1259 << getIteratorInterchange();
// Fragment of the linalg.copy -> memref.copy conversion: the target must be
// a linalg.copy with pure buffer semantics, exactly one input and one output,
// a shaped input, and matching element types. Each violation produces a
// silenceable error with a note at the target op.
1274 if (!isa<linalg::CopyOp>(targetOp)) {
1276 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1277 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1281 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1282 if (!copyOp.hasPureBufferSemantics()) {
1284 emitSilenceableError()
1285 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1286 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1292 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1293 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1294 Value input = inputs.front();
1295 Value output = outputs.front();
// The input must be shaped; the output is asserted (not diagnosed) to be
// shaped, so only the input check is user-visible.
1300 if (!isa<ShapedType>(input.
getType())) {
1302 emitSilenceableError()
1303 <<
"cannot transform a linalg.copy which input has no shape";
1304 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1309 assert(isa<ShapedType>(output.
getType()));
1311 if (cast<ShapedType>(input.
getType()).getElementType() !=
1312 cast<ShapedType>(output.
getType()).getElementType()) {
1314 emitSilenceableError()
1315 <<
"cannot transform a linalg.copy with different source and "
1316 "destination element types ";
1317 diag.attachNote(targetOp->
getLoc()) <<
"target op";
// LowerPackOp fragment: lowers a pack op to pad + expand_shape + transpose
// (optionally treating pad-like packs as insert_slice), exposing the three
// created ops as results.
1338 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1339 FailureOr<LowerPackResult> res =
1340 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1343 <<
"cannot lower to pad + expand + transpose";
1346 transformResults.
push_back(res->expandShapeOp);
1347 transformResults.
push_back(res->transposeOp);
// LowerUnPackOp fragment: lowers an unpack op to
// empty + transpose + collapse_shape + extract_slice, exposing all four
// created ops as results.
1360 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1361 FailureOr<LowerUnPackOpResult> res =
1362 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1365 emitSilenceableError()
1366 <<
"cannot lower to transpose + collapse + extract";
1367 diag.attachNote(target->getLoc()) <<
"target payload op";
1370 transformResults.
push_back(res->emptyOp);
1371 transformResults.
push_back(res->transposeOp);
1372 transformResults.
push_back(res->collapseShapeOp);
1373 transformResults.
push_back(res->extractSliceOp);
// Fragment of MatchOp::apply: walks the single payload op and collects ops
// matching the optional filters — op-name set, interface kind, attribute
// dictionary, result type, and operand types. The operand-type filter
// supports a single-type shorthand (all operands must have that type) or a
// per-operand list that must match in length.
1403 if (getOps().has_value())
1404 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1406 auto payloadOps = state.getPayloadOps(getTarget());
1407 if (!llvm::hasSingleElement(payloadOps)) {
1412 bool incorrectNumOperandTypes =
false;
// Interface filter: LinalgOp / TilingInterface / LoopLikeInterface.
1419 if (getInterface().has_value()) {
1420 auto iface = getInterface().value();
1421 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1424 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1425 !isa<TilingInterface>(op))
1427 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1428 !isa<LoopLikeOpInterface>(op))
// Attribute filter: every entry (except the interface/ops attrs themselves)
// must be present on the op with an equal value.
1433 if (getOpAttrs().has_value()) {
1434 DictionaryAttr opAttrs = getOpAttrs().value();
1436 if (attr.getName() == getInterfaceAttrName() ||
1437 attr.getName() == getOpsAttrName())
1439 if (!op->
hasAttr(attr.getName()))
1441 if (op->
getAttr(attr.getName()) != attr.getValue())
1446 if (getFilterResultType().has_value()) {
1447 Type t = getFilterResultType().value();
1452 if (getFilterOperandTypes().has_value()) {
1453 mlir::ArrayAttr types = getFilterOperandTypes().value();
1456 if (types.size() == 1) {
1459 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1460 Type t = cast<::mlir::Type>(typeattr.getValue());
1462 [&](
Type operandType) { return operandType == t; }))
// A multi-type list must have one entry per operand; a mismatch is recorded
// and diagnosed after the walk (definite failure below).
1467 if (types.size() != operandTypes.size()) {
1468 incorrectNumOperandTypes =
true;
1472 for (
auto [attr, operandType] :
1473 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1474 auto typeattr = cast<mlir::TypeAttr>(attr);
1475 Type type = cast<::mlir::Type>(typeattr.getValue());
1477 if (type != operandType)
1488 (*payloadOps.begin())->
walk(matchFun);
1489 if (incorrectNumOperandTypes)
1491 "type, then it must contain as much types as "
1492 "the number of operands in the target ops");
1493 results.
set(cast<OpResult>(getResult()), res);
// Custom parser fragment: parses a trailing functional type `(T) -> U` and
// distributes it — target gets the single input type, and low/high/split
// sizes all share the single result type.
1508 Type &targetType,
Type &lowSizeType,
1510 Type &splitPointType) {
1511 FunctionType funcType;
1513 if (failed(parser.
parseType<FunctionType>(funcType)))
1516 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1517 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1518 "argument and one result";
1520 targetType = funcType.getInput(0);
1521 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
// Fragment of MultiTileSizesOp::applyToOne. Param-typed results take the
// static path: the payload op must be statically shaped, sizes are computed
// from (dimension, target_size, divisor) and returned as attributes. The
// value-handle path emits size-computation IR and returns the defining ops.
1529 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1530 if (target.hasDynamicShape()) {
1531 auto diag = emitSilenceableError()
1532 <<
"cannot compute parametric tile sizes for dynamically "
1533 "shaped payload op";
1534 diag.attachNote(target->getLoc()) <<
"payload op";
1539 target, getDimension(), getTargetSize(), getDivisor());
1541 return emitSilenceableError()
1542 <<
"failed to compute multi-size tiling sizes";
1545 Builder builder(target.getContext());
1546 results.
assign(llvm::map_range(
1548 spec->lowTileSize * spec->lowTripCount}),
1549 [&builder,
this](int64_t value) {
// Dynamic path: materialize the size computation as IR.
1561 builder, target, getDimension(), targetSize, divisor);
1563 return emitSilenceableError() <<
"could not generate tile size computation";
1570 {spec->lowTileSize, spec->lowTripCount});
1571 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1572 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1573 assert(lowTileSize && highTileSize && splitPoint &&
1574 "tile sizes are not produced by operations");
// getEffects fragment: effects differ by whether results are params or
// handles; the verifier requires all results to share one type.
1582 void transform::MultiTileSizesOp::getEffects(
1586 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1595 return emitOpError() <<
"expects all results type to be the same";
// PackOp fragments: builder, mixed packed-size accessor, and apply().
// apply() requires exactly one LinalgOp payload whose loop count matches the
// number of packed sizes, then calls pack() and maps the packed op onto the
// result handle. An empty target range short-circuits with an empty result.
1615 builder.
getContext(), GenericOp::getOperationName());
1616 build(builder, result,
1625 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1632 auto targetOps = state.getPayloadOps(getTarget());
1634 if (std::empty(targetOps)) {
1635 transformResults.
set(cast<OpResult>(getPackedOp()),
1640 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1641 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1642 return emitSilenceableError()
1643 <<
"requires target to map to exactly 1 LinalgOp (got "
1644 << llvm::range_size(targetOps) <<
")";
1647 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1648 return emitSilenceableError()
1649 <<
"requires number of packed sizes match the number of loops ("
1650 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
// Resolve mixed (static/dynamic) packed sizes against the transform state,
// then perform the packing.
1657 state, *
this, packedSizes, getMixedPackedSizes());
1660 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1661 if (failed(maybeResult))
1664 transformResults.
set(cast<OpResult>(getPackedOp()),
1665 {maybeResult->packedLinalgOp.getOperation()});
// PackGreedilyOp fragments. The verifier enforces that matmul_inner_dims_order
// is a permutation and that packed_size / padded_sizes_next_multiple_of are
// mutually exclusive per dimension. apply() packs each LinalgOp payload with
// the matmul strategy; on failure the original op is passed through.
1669 void transform::PackOp::getEffects(
1683 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1684 <<
" is not a valid permutation";
1687 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1688 for (
auto [s, nmo] :
1689 llvm::zip_equal(getMixedMatmulPackedSizes(),
1690 getMatmulPaddedSizesNextMultipleOf())) {
1693 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1694 return emitOpError() <<
"at most one of the packed_size and the "
1695 "padded_sizes_next_multiple_of can be nonzero "
1696 "for the matmul strategy";
1708 for (
Operation *op : state.getPayloadOps(getTarget())) {
1709 auto linalgOp = dyn_cast<LinalgOp>(op);
1720 getMixedMatmulPackedSizes(),
1722 getMatmulPaddedSizesNextMultipleOf(),
1723 getMatmulInnerDimsOrder());
1724 if (succeeded(packResult)) {
1725 results.push_back(packResult->packedLinalgOp);
1728 results.push_back(linalgOp);
1730 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1736 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1740 void transform::PackGreedilyOp::getEffects(
// PackTransposeOp::verify fragment: inner_perm / outer_perm must each be a
// valid permutation, and at least one of them must be non-empty.
1754 return emitOpError() << getInnerPermAttrName()
1755 <<
" is not a valid permutation";
1758 return emitOpError() << getOuterPermAttrName()
1759 <<
" is not a valid permutation";
1761 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1762 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1763 <<
" or " << getOuterPermAttrName()
1764 <<
" must be specified";
// Fragment of a helper validating a permutation against a pack/unpack op:
// inner permutations are checked against the inner_dims_pos rank; outer
// permutations against the source rank (pack) or destination rank (unpack).
// Constrained at compile time to linalg::PackOp / linalg::UnPackOp.
1770 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1780 template <
typename RelayoutOpTy>
1783 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1785 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1786 "applies to only pack or unpack operations");
// A null op or empty permutation is trivially rejected (per visible code the
// function presumably returns false here — confirm against full file).
1787 if (!op || permutation.empty())
1789 size_t innerRank = op.getInnerDimsPos().size();
1790 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1794 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1795 return permutation.size() == op.getSourceRank() &&
1798 return permutation.size() == op.getDestRank() &&
// Fragment of PackTransposeOp::apply: requires exactly one pack-or-unpack
// payload op and one LinalgOp payload linked to it (pack feeding the linalg
// op through a single use, or unpack produced by it), validates the
// requested inner/outer permutations, then performs packTranspose and maps
// the three result handles.
1806 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1807 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
// Empty target: propagate empty results for all three handles.
1809 if (std::empty(packOrUnpackOps)) {
1810 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1811 transformResults.
set(cast<OpResult>(getPackOp()), {});
1812 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1818 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1819 !llvm::hasSingleElement(linalgOps)) {
1820 return emitSilenceableError()
1821 <<
"requires target to map to exactly 1 "
1822 "packing op and 1 packed op ("
1823 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1824 << llvm::range_size(linalgOps) <<
")";
1828 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1829 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1830 if ((!packOp && !unPackOp)) {
1831 return emitSilenceableError() <<
"requires target to map to a "
1832 "linalg.pack or linalg.unpack";
1834 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1835 if (!linalgOpTarget)
1836 return emitSilenceableError() <<
"requires a LinalgOp target";
// The pack must feed the LinalgOp target through its single use, or the
// unpack must consume the LinalgOp target's result.
1840 if (packOp && packOp.getResult().hasOneUse())
1841 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1843 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1844 if (linalgOp != linalgOpTarget) {
1846 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1847 : StringLiteral{
"not produced by the LinalgOp target"};
1848 return emitSilenceableError() << errorMsg;
// For an unpack, recover the matching pack through the tied DPS init.
1854 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1855 OpOperand *packUse = linalgOp.getDpsInitOperand(
1856 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1858 if (!packOp || !packOp.getResult().hasOneUse())
1859 return emitSilenceableError() <<
"could not find matching pack op";
// Validate both permutations against the relevant op.
1863 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1865 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1866 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1867 ? StringLiteral{
"invalid outer_perm"}
1868 : StringLiteral{
"invalid inner_perm"};
1872 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1873 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1879 assert(packOp && linalgOp &&
"unexpected null op");
1883 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1885 assert(succeeded(res) &&
"unexpected packTranspose failure");
// Map results; the unpack handle is empty when there was no unpack.
1888 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1889 transformResults.
set(cast<OpResult>(getPackedOp()),
1890 {res->transposedLinalgOp});
1892 transformResults.
set(cast<OpResult>(getUnPackOp()),
1893 {res->transposedUnPackOp});
1895 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1910 StringRef copyBackOp) {
1920 (padToMultipleOf.empty()
1922 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1923 b.getI64ArrayAttr(nofoldFlags),
1924 b.getArrayAttr(transposePaddings),
1925 b.getStringAttr(copyBackOp));
1933 StringRef copyBackOp) {
1938 staticPadToMultipleOf);
1945 dynamicPadToMultipleOf,
1946 staticPadToMultipleOf,
1952 void PadOp::getEffects(
1962 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1969 auto transformOp = cast<TransformOpInterface>(getOperation());
1972 for (
Operation *target : state.getPayloadOps(getTarget())) {
1973 auto linalgTarget = dyn_cast<LinalgOp>(target);
1974 if (!linalgTarget) {
1975 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1976 diag.attachNote(target->
getLoc()) <<
"target op";
1982 for (int64_t packPadding :
1983 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1984 nofoldFlags.push_back(
static_cast<bool>(packPadding));
1988 for (
auto const &it :
1989 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1990 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1992 emitOpError(
"expects padding values to be typed attributes");
1997 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2001 if (!parsedAttr || parsedAttr.getType() != elementType) {
2002 auto diag = this->emitOpError(
"expects a padding that parses to ")
2003 << elementType <<
", got " << std::get<0>(it);
2004 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2007 paddingValues.push_back(parsedAttr);
2011 if (attr.getType() != elementType) {
2012 auto diag = this->emitOpError(
"expects a padding value of type ")
2013 << elementType <<
", got " << attr;
2014 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2017 paddingValues.push_back(attr);
2022 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2023 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2024 cast<ArrayAttr>(transposeVector)));
2029 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2033 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2036 if (padToMultipleOf.empty())
2040 options.padToMultipleOf = padToMultipleOf;
2041 options.paddingValues = paddingValues;
2042 options.nofoldFlags = nofoldFlags;
2043 if (getCopyBackOp() ==
2044 bufferization::MaterializeInDestinationOp::getOperationName()) {
2047 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2049 }
else if (getCopyBackOp() == kCopyOpNone) {
2052 llvm_unreachable(
"unsupported copy_back op");
2058 replacements, newPadOps))) {
2059 auto diag = emitSilenceableError() <<
"failed to pad op";
2060 diag.attachNote(target->
getLoc()) <<
"target op";
2069 rewriter.
replaceOp(linalgTarget, replacements);
2070 paddedOps.push_back(paddedOp);
2071 padOps.append(newPadOps.begin(), newPadOps.end());
2073 for (
Value v : replacements) {
2074 Operation *copyBackOp = v.getDefiningOp();
2075 if (!llvm::is_contained(copyBackOps, copyBackOp))
2076 copyBackOps.push_back(copyBackOp);
2081 results.
set(cast<OpResult>(getPadded()), paddedOps);
2082 results.
set(cast<OpResult>(getPad()), padOps);
2083 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2089 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2090 if (any_of(nofoldFlags, [](int64_t packPadding) {
2091 return packPadding != 0 && packPadding != 1;
2093 return emitOpError()
2094 <<
"expects nofold_flags to contain booleans (0/1), found "
2095 << getNofoldFlags();
2099 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2100 if (any_of(paddingDimensions,
2101 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2102 return emitOpError() <<
"expects padding_dimensions to contain positive "
2104 << getPaddingDimensions();
2106 if (!getMixedPadToMultipleOf().empty()) {
2107 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2108 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2111 ArrayAttr transposes = getTransposePaddings();
2114 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2115 if (!std::is_permutation(sequence.begin(), sequence.end(),
2117 return emitOpError()
2118 <<
"expects transpose_paddings to be a permutation, found "
2122 if (getCopyBackOp() !=
2123 bufferization::MaterializeInDestinationOp::getOperationName() &&
2124 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2125 getCopyBackOp() != kCopyOpNone)
2126 return emitOpError() <<
"invalid copy_back_op";
2138 auto targetOps = state.getPayloadOps(getTarget());
2139 auto loopOps = state.getPayloadOps(getLoop());
2140 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2142 <<
"requires exactly one target and one loop handle (got "
2143 << llvm::range_size(targetOps) <<
" and "
2144 << llvm::range_size(loopOps) <<
")";
2147 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2148 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2149 if (!padOp || !loopOp)
2152 FailureOr<linalg::detail::PackingResult> result =
2158 if (result->clonedLoopIvs.empty()) {
2159 transformResults.
set(cast<OpResult>(getPackingLoop()),
2160 {result->hoistedPadOp.getOperation()});
2163 auto outerPackedLoop =
2165 transformResults.
set(cast<OpResult>(getPackingLoop()),
2166 {outerPackedLoop.getOperation()});
2172 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2173 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2175 return emitOpError() <<
"expects transpose to be a permutation, found "
2181 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2191 tensor::PadOp target,
2194 tensor::PadOp hoistedPadOp;
2196 FailureOr<Value> result =
2198 hoistedPadOp, transposeOps);
2199 if (succeeded(result)) {
2209 return emitDefaultSilenceableFailure(target);
2214 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2215 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2217 return emitOpError() <<
"expects transpose to be a permutation, found "
2233 if (!getOperandsToPromote().empty())
2235 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2236 if (getUseFullTilesByDefault())
2238 getUseFullTilesByDefault());
2240 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2241 if (!getUseFullTileBuffers().empty())
2243 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2244 if (getAlignment().has_value())
2245 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2246 if (getMemorySpace().has_value())
2247 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2249 if (getMapping().has_value()) {
2251 auto mapping = *getMapping();
2252 if (mapping.size() > 1)
2253 return emitDefaultDefiniteFailure(target);
2255 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2257 if (addressSpace.getAddressSpace() ==
2258 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2265 }
else if (addressSpace.getAddressSpace() ==
2266 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2274 return emitDefaultDefiniteFailure(target);
2279 return emitDefaultDefiniteFailure(target);
2282 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2284 return emitDefaultDefiniteFailure(target);
2297 auto payload = state.getPayloadOps(getTarget());
2301 if (target->getNumOperands() > 0)
2304 target->getNumRegions() > 0)
2306 <<
"expected target that is isolated from above";
2310 Operation *pattern = &getBodyRegion().front().front();
2313 if (getOperation()->isAncestor(target))
2318 replacements.push_back(replacement);
2320 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2324 void transform::ReplaceOp::getEffects(
2332 if (!getBodyRegion().hasOneBlock())
2333 return emitOpError() <<
"expected one block";
2334 if (std::distance(getBodyRegion().front().begin(),
2335 getBodyRegion().front().end()) != 1)
2336 return emitOpError() <<
"expected one operation in block";
2337 Operation *replacement = &getBodyRegion().front().front();
2340 <<
"expected replacement without operands";
2344 <<
"expect op that is isolated from above";
2362 target.createFlatListOfOperandDims(b, loc);
2363 AffineMap map = target.getShapesToLoopsMap();
2378 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2379 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2380 if (failed(maybeTilingResult))
2381 return emitDefaultDefiniteFailure(target);
2383 if (target->getNumResults())
2384 rewriter.
replaceOp(target, maybeTilingResult->mergeResult.replacements);
2388 results.
reserve(maybeTilingResult->tiledOps.size());
2389 for (
Operation *tiled : maybeTilingResult->tiledOps)
2403 for (
Operation *target : state.getPayloadOps(getTarget())) {
2404 auto tilingOp = dyn_cast<TilingInterface>(*target);
2407 emitSilenceableError()
2408 <<
"expected the payload to implement TilingInterface";
2409 diag.attachNote(target->getLoc()) <<
"payload op";
2413 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2415 if (failed(generatedLoops))
2416 return emitDefaultDefiniteFailure(target);
2417 for (scf::ForOp &loop : *generatedLoops) {
2418 loops.push_back(loop.getOperation());
2422 results.
set(cast<OpResult>(getResult()), loops);
2431 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2436 FailureOr<Operation *> maybeResult =
2438 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2439 [&rewriter](
auto op) {
2442 if (failed(maybeResult))
2443 return emitDefaultSilenceableFailure(target);
2457 llvm::to_vector(state.getPayloadOps(getTarget()));
2459 bool isMultiwaySplit = getMultiway();
2461 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2463 <<
"requires exactly one target when "
2464 "multiway split is enabled (got "
2465 << llvm::range_size(payload) <<
")";
2470 if (!isMultiwaySplit)
2471 chunkSizes.reserve(payload.size());
2473 if (getDynamicChunkSizes()) {
2475 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2476 chunkSizes = llvm::to_vector(llvm::map_range(
2477 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2480 diag = emitSilenceableError()
2481 <<
"expected dynamic split point handle to point to a "
2482 "single-result index-typed op";
2483 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2488 chunkSizes = llvm::to_vector(
2489 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2492 if (
diag.isSilenceableFailure())
2497 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2499 <<
"expected the dynamic split point handle to point to as "
2501 << chunkSizes.size() <<
") as the target handle ("
2502 << payload.size() <<
")";
2505 chunkSizes.resize(payload.size(),
2509 auto checkStructuredOpAndDimensions =
2512 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2513 diag.attachNote(loc) <<
"target op";
2517 if (getDimension() >= linalgOp.getNumLoops()) {
2518 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2519 <<
" does not exist in target op";
2520 diag.attachNote(loc) <<
"target op";
2526 auto checkFailureInSplitting =
2530 diag.attachNote(loc) <<
"target op";
2537 if (isMultiwaySplit) {
2540 TilingInterface head, tail;
2543 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2547 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2548 if (
diag.isSilenceableFailure())
2554 target = tail.getOperation();
2559 linalgOp = cast<LinalgOp>(target);
2564 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2565 getDimension(), chunkSize);
2569 checkFailureInSplitting(!head && !tail, loc);
2570 if (
diag.isDefiniteFailure())
2573 opList.push_back(head.getOperation());
2578 opList.push_back(tail.getOperation());
2584 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2587 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2589 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2591 if (
diag.isSilenceableFailure())
2595 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2596 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2597 getDimension(), std::get<1>(pair));
2601 checkFailureInSplitting(!first.back() && !second.back(), loc);
2606 if (!second.back()) {
2607 noSecondPart = target;
2612 if (second.size() != first.size() && !second.empty()) {
2613 auto diag = emitSilenceableError()
2614 <<
"splitting does not produce the second part for a subset "
2617 <<
"expected splitting to produce the second part of all "
2618 "or none of the targets";
2620 <<
"first target with no second part";
2624 opList.append(first);
2626 opList.append(second);
2628 results.
set(cast<OpResult>(getSplitList()), opList);
2632 void SplitOp::getEffects(
2635 if (getDynamicChunkSizes())
2643 IntegerAttr staticChunkSizes;
2649 if (!dynamicPointParseResult.
has_value()) {
2650 int64_t staticChunkSizesValue;
2651 if (failed(parser.
parseInteger(staticChunkSizesValue)))
2664 if (dynamicPointParseResult.
has_value()) {
2665 Type ChunkSizesType;
2666 if (failed(*dynamicPointParseResult) || parser.
parseComma() ||
2678 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2685 printer <<
" " << getTarget() <<
" after ";
2686 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2687 if (staticChunkSize != ShapedType::kDynamic)
2688 printer << staticChunkSize;
2690 printer << getDynamicChunkSizes();
2693 {getStaticChunkSizesAttrName()});
2694 printer <<
" : " << getTarget().getType();
2695 if (staticChunkSize == ShapedType::kDynamic)
2696 printer <<
", " << getDynamicChunkSizes().getType();
2700 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2701 (getDynamicChunkSizes() ==
nullptr)) {
2702 return emitOpError() <<
"expects either a dynamic or a static split "
2703 "point to be provided";
2712 void transform::SplitReductionOp::build(
2714 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2715 bool useScalingAlgorithm,
bool useAlloc) {
2721 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2723 if (innerParallel) {
2724 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2727 if (useScalingAlgorithm) {
2729 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2737 result.
addTypes({resultType, resultType, resultType, resultType});
2746 unsigned(getInsertSplitDimension()),
2747 bool(getInnerParallel())};
2750 FailureOr<SplitReductionResult> splitResult =
2751 (getUseScalingAlgorithm())
2754 if (failed(splitResult))
2755 return emitDefaultDefiniteFailure(target);
2757 results.
push_back(splitResult->initOrAlloc);
2759 results.
push_back(splitResult->splitLinalgOp);
2760 results.
push_back(splitResult->resultCombiningLinalgOp);
2768 void transform::TileReductionUsingForOp::build(
2779 build(builder, result,
2782 staticTileSizesAttr);
2791 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2792 if (!partialReductionOp) {
2795 "Operation should implement PartialReductionOpInterface");
2798 rewriter, partialReductionOp,
2802 return emitDefaultSilenceableFailure(target);
2803 rewriter.
replaceOp(target, result->mergeResult.replacements);
2804 for (
Value initValue : result->initialValues)
2806 for (
auto parallelTiledOp : result->tiledOps)
2808 for (
auto mergeOp : result->mergeResult.mergeOps)
2810 results.
push_back(result->loops.front());
2818 void transform::TileReductionUsingForallOp::build(
2821 ArrayAttr mapping) {
2831 build(builder, result,
2834 staticNumThreadsAttr,
2835 staticTileSizesAttr,
2848 FailureOr<linalg::ForallReductionTilingResult> result =
2850 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2851 numThreads, tileSizes, getMapping());
2853 if (failed(result)) {
2854 auto diag = emitSilenceableError() <<
"could not tile reduction";
2855 diag.attachNote(target.getLoc()) <<
"target operation";
2858 for (
Value initValue : result->initialValues)
2860 for (
auto parallelTiledOp : result->parallelTiledOps)
2862 for (
auto mergeOp : result->mergeOps)
2878 llvm::to_vector(state.getPayloadOps(getTarget()));
2880 if (!llvm::hasSingleElement(targetOps)) {
2882 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
2887 auto linalgOp = dyn_cast<LinalgOp>(target);
2888 auto tileableOp = dyn_cast<TilingInterface>(target);
2893 OpBuilder builder(linalgOp.getContext());
2895 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
2896 if (linalgOp.hasDynamicShape()) {
2897 auto diag = emitSilenceableError()
2898 <<
"cannot compute parametric tile sizes for dynamically "
2899 "shaped payload op";
2900 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
2904 FailureOr<StaticContinuousTileSizeSpecification> spec =
2908 return emitSilenceableError()
2909 <<
"failed to compute multi-size tiling sizes";
2914 for (
auto &&[tileSize, tripCount] :
2915 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2916 chunkSizes.push_back(tileSize * tripCount);
2919 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
2923 transformResults.
setParams(cast<OpResult>(getTileSizes()),
2924 getI64AttrsFromI64(spec->tileSizes));
2925 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
2926 getI64AttrsFromI64(chunkSizes));
2934 unsigned dimension = getDimension();
2937 builder, tileableOp, dimension, targetSize,
true);
2939 return emitSilenceableError() <<
"could not generate tile size computation";
2951 for (
auto &&[tileSize, tripCount] :
2952 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2953 splitPoint = apply(s0 * s1, {tileSize, tripCount});
2954 chunkSizes.push_back(splitPoint);
2958 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
2963 transformResults.
set(cast<OpResult>(getTileSizes()),
2964 getDefiningOps(spec->tileSizes));
2965 transformResults.
set(cast<OpResult>(getChunkSizes()),
2966 getDefiningOps(chunkSizes));
2974 return emitOpError() <<
"expects all results type to be the same";
2980 void transform::ContinuousTileSizesOp::getEffects(
2982 if (isa<TransformParamTypeInterface>(getTileSizes().
getType()))
2998 Type &tileSizesType,
2999 Type &chunkSizesType) {
3000 FunctionType funcType;
3002 if (failed(parser.
parseType<FunctionType>(funcType)))
3005 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3006 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3007 "argument and one result";
3009 targetType = funcType.getInput(0);
3010 tileSizesType = chunkSizesType = funcType.getResult(0);
3019 void transform::TileUsingForOp::build(
3024 return build(builder, result, loopTypes,
3028 interchange, scalableSizes);
3031 void transform::TileUsingForOp::build(
3035 build(builder, result, target,
3037 interchange, scalableSizes);
3040 void transform::TileUsingForOp::build(
3047 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3051 void transform::TileUsingForOp::build(
3063 unsigned numExpectedLoops =
3064 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3066 resultTypes.reserve(numExpectedLoops);
3067 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3068 "expected one loop type or as many as loops");
3069 if (loopTypes.size() == 1)
3070 resultTypes.append(numExpectedLoops, loopTypes[0]);
3072 llvm::append_range(resultTypes, loopTypes);
3074 if (scalableSizes.has_value())
3075 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3076 build(builder, result, target.
getType(),
3080 staticTileSizesAttr,
3082 expandedScalableSizes);
3087 return emitOpError(
"expected same number of sizes (")
3089 << getScalableSizes().size() <<
")";
3091 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3092 if (getLoops().size() != numExpectedLoops)
3093 return emitOpError(
"expected number of loops to tile (")
3094 << numExpectedLoops <<
") to match number of `loops` results ("
3095 << getLoops().size() <<
")";
3106 llvm::to_vector(state.getPayloadOps(getTarget()));
3112 if (isa<ParamType>(transformValue.getType())) {
3113 dynamicSizeProducers.push_back({});
3115 paramSizes.push_back(
3116 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
3117 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3120 if (paramSizes.back().size() != targets.size()) {
3122 emitSilenceableError()
3123 <<
"expected as many parameter values ("
3124 << dynamicSizeProducers.back().size() <<
") as target ops ("
3125 << targets.size() <<
")";
3126 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3132 paramSizes.push_back({});
3133 dynamicSizeProducers.push_back(
3134 llvm::to_vector(state.getPayloadOps(transformValue)));
3136 if (dynamicSizeProducers.back().size() != targets.size()) {
3138 emitSilenceableError()
3139 <<
"expected as many dynamic size-producing operations ("
3140 << dynamicSizeProducers.back().size() <<
") as target ops ("
3141 << targets.size() <<
")";
3142 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3146 for (
Operation *op : dynamicSizeProducers.back()) {
3153 emitSilenceableError() <<
"expected sizes to be produced by ops "
3154 "with a single index-type result";
3155 diag.attachNote(op->
getLoc()) <<
"size producer op";
3156 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3163 loops.resize(getLoops().size());
3164 auto scalableSizes = getScalableSizes();
3166 auto tilingInterface = dyn_cast<TilingInterface>(op);
3167 if (!tilingInterface) {
3169 emitSilenceableError()
3170 <<
"only ops implementing TilingInterface are supported";
3171 diag.attachNote(op->
getLoc()) <<
"target op";
3174 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3176 emitSilenceableError()
3177 <<
"too many tiles provided, expected at most "
3178 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3179 << tileSizes.size();
3180 diag.attachNote(op->
getLoc()) <<
"target op";
3185 if (tileSizes.empty()) {
3194 sizes.reserve(tileSizes.size());
3195 unsigned dynamicIdx = 0;
3198 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3199 if (scalableSizes[ofrIdx]) {
3200 auto val = b.
create<arith::ConstantIndexOp>(
3201 getLoc(), cast<IntegerAttr>(attr).getInt());
3207 sizes.push_back(attr);
3214 assert((dynamicSizes.empty() ^ params.empty()) &&
3215 "expected either dynamic sizes or parameters");
3216 if (!params.empty()) {
3219 sizes.push_back(dynamicSizes[index]->getResult(0));
3227 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3228 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3229 if (failed(maybeTilingResult))
3232 rewriter.
replaceOp(op, maybeTilingResult->mergeResult.replacements);
3234 tiled.append(maybeTilingResult->tiledOps);
3236 loops[en2.index()].push_back(en2.value());
3239 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3241 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3250 results.reserve(tileSizes.size());
3251 unsigned dynamicPos = 0;
3253 for (int64_t size : tileSizes) {
3254 if (size == ShapedType::kDynamic) {
3255 results.push_back(dynamic[dynamicPos++]);
3263 void transform::TileUsingForOp::getEffects(
3275 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3279 ArrayAttr mapping) {
3280 return build(builder, result,
3288 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3292 ArrayAttr mapping) {
3302 build(builder, result,
3303 TypeRange{operationType, operationType},
3310 staticTileSizesAttr,
3314 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3318 ArrayAttr mapping) {
3319 return build(builder, result, target,
3324 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3328 ArrayAttr mapping) {
3339 build(builder, result,
3340 TypeRange{operationType, operationType},
3346 staticNumThreadsAttr,
3359 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3361 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3363 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3364 normalizedUbs.push_back(normalizedUb);
3366 return normalizedUbs;
3382 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3385 denormalizedIvs.push_back(
3388 return denormalizedIvs;
3399 scf::ForallOp loop) {
3419 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
3420 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3423 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3425 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3430 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3431 normalizedForallOp.getRegionIterArgs().end());
3432 Block *origLoopBlock = loop.getBody();
3433 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3435 rewriter.
replaceOp(loop, normalizedForallOp);
3436 return normalizedForallOp;
3441 TransformOpInterface transformOp,
Operation *target,
3446 auto tileableOp = dyn_cast<TilingInterface>(target);
3449 transformOp.emitSilenceableError()
3450 <<
"only TilingInterface ops are supported";
3451 diag.attachNote(target->
getLoc()) <<
"target op";
3457 if (!mixedNumThreads.empty()) {
3458 options.setNumThreads(mixedNumThreads);
3460 options.setTileSizes(mixedTileSizes);
3463 options.setMapping(mapping.value().getValue());
3465 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3468 if (failed(maybeTilingResult))
3469 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3471 rewriter.
replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3473 tilingResult = *maybeTilingResult;
3475 if (mixedNumThreads.empty()) {
3476 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3479 scf::ForallOp normalizedForallOp =
3481 tilingResult.
loops.front() = normalizedForallOp;
3491 auto transformOp = cast<TransformOpInterface>(getOperation());
3500 getPackedNumThreads()
3502 state, transformOp, mixedNumThreads, getPackedNumThreads())
3504 state, transformOp, mixedNumThreads, getMixedNumThreads());
3508 status = getPackedTileSizes()
3510 state, transformOp, mixedTileSizes, getPackedTileSizes())
3512 state, transformOp, mixedTileSizes, getMixedTileSizes());
3516 for (
Operation *target : state.getPayloadOps(getTarget())) {
3519 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3520 getMapping(), tilingResult);
3521 if (!
diag.succeeded())
3523 tileOps.push_back(tilingResult.
loops.front());
3524 tiledOps.append(tilingResult.
tiledOps);
3527 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3528 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3533 void transform::TileUsingForallOp::getEffects(
3546 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3555 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3556 static_cast<int>(getPackedNumThreads() !=
Value());
3557 if (numThreadsSpec > 1)
3559 "num_threads and packed_num_threads are mutually exclusive");
3560 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3561 static_cast<int>(getPackedTileSizes() !=
Value());
3562 if (tileSizesSpec > 1)
3564 "tile_sizes and packed_tile_sizes are mutually exclusive");
3565 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3566 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3567 "must be specified");
3575 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3577 bool vectorizePadding,
bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3579 if (vectorizePadding) {
3581 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3585 if (vectorizeExtract) {
3587 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3591 if (flatten1DDepthwiseConv) {
3593 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3604 explicit VectorizationPattern(
MLIRContext *context,
3605 bool vectorizeExtract =
false,
3606 bool flattenConv =
false)
3608 vectorizeNDExtract(vectorizeExtract),
3609 flatten1DDepthwiseConv(flattenConv) {}
3610 LogicalResult matchAndRewrite(
Operation *op,
3614 "Unsupported Op, cannot vectorize");
3616 {}, vectorizeNDExtract,
3617 flatten1DDepthwiseConv);
3623 bool vectorizeNDExtract =
false;
3627 bool flatten1DDepthwiseConv =
false;
3632 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3637 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3638 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3644 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3645 getFlatten_1dDepthwiseConv());
3647 if (!getDisableTransferPermutationMapLoweringPatterns())
3650 if (!getDisableMultiReductionToContractPatterns())
3658 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
3659 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
3664 if (getVectorizePadding()) {
3676 return emitDefaultDefiniteFailure(target);
3690 auto targets = state.getPayloadOps(getTarget());
3691 if (std::empty(targets))
3693 auto transformOp = cast<TransformOpInterface>(getOperation());
3696 state, transformOp, getMixedVectorSizes(), vectorSizes);
3704 <<
"Unsupported Op, cannot vectorize";
3709 getVectorizeNdExtract().value_or(
false)))) {
3711 <<
"Attempted to vectorize, but failed";
3718 void transform::VectorizeOp::getEffects(
3727 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3731 if (getStaticVectorSizes().size() != getScalableSizes().size())
3732 return emitOpError(
"expected same number of vector sizes (")
3733 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3734 << getScalableSizes().size() <<
")";
3743 transform::HoistRedundantVectorTransfersOp::applyToOne(
3760 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3779 auto maybeTransformed =
3782 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3785 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3788 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3791 .Case([&](linalg::Conv2DNchwFchwOp op) {
3797 if (failed(maybeTransformed))
3798 return emitDefaultSilenceableFailure(target);
3800 results.
push_back(maybeTransformed->first);
3802 results.
push_back(maybeTransformed->second);
3817 <<
"only elementwise flattening is supported";
3820 if (target.getNumLoops() <= 1) {
3827 std::iota(reassociation.begin(), reassociation.end(), 0);
3828 auto maybeFlattened =
3830 if (failed(maybeFlattened))
3832 <<
"attempted to flatten, but failed";
3833 results.
push_back(maybeFlattened->collapsedOp);
3834 rewriter.
replaceOp(target, maybeFlattened->results);
3847 auto maybeTransformed =
3849 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3852 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3858 if (failed(maybeTransformed))
3859 return emitDefaultSilenceableFailure(target);
3874 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3875 auto maybeTransformed =
3877 .Case([&](linalg::MatmulOp op) {
3880 .Case([&](linalg::BatchMatmulOp op) {
3883 .Default([&](
Operation *op) {
return failure(); });
3884 if (failed(maybeTransformed))
3894 template <
typename OpTy>
3898 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3899 tensor::ParallelInsertSliceOp>() &&
3902 if (
auto copySource =
3903 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3910 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3912 target->template getParentOfType<scf::InParallelOp>());
3915 Value extracted = rewriter.
create<tensor::ExtractSliceOp>(
3916 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3917 target.getMixedSizes(), target.getMixedStrides());
3918 Value copied = rewriter
3919 .
create<linalg::CopyOp>(target.getLoc(),
3920 target.getSource(), extracted)
3925 target, copied, target.getDest(), target.getMixedOffsets(),
3926 target.getMixedSizes(), target.getMixedStrides());
3928 results.
push_back(copied.getDefiningOp());
3938 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3939 return doit(rewriter, target, results, state);
3940 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3941 return doit(rewriter, target, results, state);
3944 emitSilenceableError()
3945 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3946 diag.attachNote(targetOp->
getLoc()) <<
"target op";
3959 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3961 emitSilenceableError()
3962 <<
"only linalg.copy and tensor.pad target ops are supported";
3963 diag.attachNote(target->
getLoc()) <<
"target op";
3966 assert(target->
getNumResults() == 1 &&
"expected single result");
3968 if (!resultShapedType.hasStaticShape()) {
3970 emitSilenceableError()
3971 <<
"only statically sized ops of rank <= 3 are supported";
3972 diag.attachNote(target->
getLoc()) <<
"target op";
3977 int64_t desiredBitAlignment = getDesiredBitAlignment();
3978 int64_t eltBitwidth =
3979 resultShapedType.getElementType().getIntOrFloatBitWidth();
3980 if (desiredBitAlignment % eltBitwidth != 0) {
3981 desiredBitAlignment = eltBitwidth;
3986 getTotalNumThreads(),
3987 desiredBitAlignment,
3988 resultShapedType.getShape(),
3991 resultShapedType.getElementType().getIntOrFloatBitWidth());
3992 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3994 emitSilenceableError()
3995 <<
"too few threads to map copy op to threads on the most minor "
3996 "dimension, given alignment and vector size constraints, try "
3997 "smaller tile size of mapping to more threads";
3998 diag.attachNote(target->
getLoc()) <<
"target op";
4014 if (!
diag.succeeded())
4018 for (
auto op : tilingResult.
tiledOps)
4032 FailureOr<Operation *> maybeTransformed = failure();
4034 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4039 .Default([&](
Operation *op) {
return false; });
4042 return emitSilenceableError()
4043 <<
"this operation is not supported to convert to Winograd Conv2D";
4046 if (failed(maybeTransformed)) {
4047 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4059 FailureOr<Operation *> maybeTransformed = failure();
4062 .Case([&](linalg::WinogradFilterTransformOp op) {
4066 .Case([&](linalg::WinogradInputTransformOp op) {
4070 .Case([&](linalg::WinogradOutputTransformOp op) {
4074 .Default([&](
Operation *op) {
return false; });
4078 emitSilenceableError()
4079 <<
"this operation is not supported to decompose into other operations";
4080 diag.attachNote(target->
getLoc()) <<
"target op";
4084 if (failed(maybeTransformed)) {
4086 emitSilenceableError() <<
"decompose Winograd operations failed";
4087 diag.attachNote(target->
getLoc()) <<
"target op";
4095 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4097 #define GET_OP_CLASSES
4098 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
uint64_t getM(LevelType lt)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.