44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
// Debug-output helpers for this file; enable with `-debug-only=linalg-transforms`.
54 #define DEBUG_TYPE "linalg-transforms"
// Prefix a debug line with the pass tag, e.g. "[linalg-transforms]: ".
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Emit a bare newline to the debug stream.
56 #define DBGSNL() (llvm::dbgs() << "\n")
// One-shot tagged debug line; compiled out of release builds via LLVM_DEBUG.
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
64 template <
typename PatternTy,
typename... Args>
67 using OpTy =
typename llvm::function_traits<
68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69 auto op = dyn_cast<OpTy>(operation);
74 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
79 explicit TrivialPatternRewriter(
MLIRContext *context)
82 TrivialPatternRewriter rewriter(operation->
getContext());
83 rewriter.setInsertionPoint(operation);
84 auto result = pattern.returningMatchAndRewrite(op, rewriter);
87 return cast<LinalgOp>(result->getOperation());
97 if (
auto attr = dyn_cast<Attribute>(ofr)) {
98 if (!isa<IntegerAttr>(attr))
99 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
100 result.push_back(ofr);
104 Value transformValue = cast<Value>(ofr);
105 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
107 if (params.size() != 1)
108 return transformOp.emitDefiniteFailure()
109 <<
"requires exactly one parameter associated";
110 result.push_back(params[0]);
114 auto payloadOps = state.getPayloadOps(transformValue);
115 if (!llvm::hasSingleElement(payloadOps)) {
117 transformOp.emitSilenceableError()
118 <<
"handle must be mapped to exactly one payload op";
120 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
127 transformOp.emitSilenceableError()
128 <<
"payload op must have exactly 1 index result";
148 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
150 for (
auto param : params) {
151 if (!isa<IntegerAttr>(param))
152 return transformOp.emitDefiniteFailure()
153 <<
"expected the parameter to be associated with an integer "
155 result.push_back(param);
160 for (
Operation *op : state.getPayloadOps(packedHandle)) {
161 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
163 transformOp.emitSilenceableError()
164 <<
"payload op must have exactly 1 index result";
165 diag.attachNote(op->getLoc())
166 <<
"has " << op->getNumResults() <<
" results";
169 result.push_back(op->getResult(0));
183 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
184 reified.push_back(cast<IntegerAttr>(attr).getInt());
186 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
188 if (params.size() != 1)
189 return transformOp.emitSilenceableError() <<
"expected a single param";
191 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
195 Value handle = cast<Value>(paramOrHandle);
196 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
197 return transformOp.emitSilenceableError() <<
"unexpected value handle";
198 auto payload = state.getPayloadOps(handle);
199 if (!llvm::hasSingleElement(payload))
200 return transformOp.emitSilenceableError()
201 <<
"requires param or handle that is mapped to 1 payload op";
203 Operation *paramOrHandlePayloadOp = *payload.begin();
206 return transformOp.emitSilenceableError()
207 <<
"requires param or handle to be result of op with 1 index "
213 return transformOp.emitSilenceableError()
214 <<
"requires param or handle to be the result of a constant like "
217 reified.push_back(attr.getInt());
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
250 options.rankReductionStrategy =
255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
284 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
289 resultTypes.push_back(b.
getType<transform::AnyValueType>());
290 resultTypes.push_back(b.
getType<transform::AnyOpType>());
291 return build(b, result,
297 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
300 int64_t memorySpace) {
302 resultTypes.push_back(b.
getType<transform::AnyValueType>());
303 resultTypes.push_back(b.
getType<transform::AnyOpType>());
304 return build(b, result,
320 void notifyOperationInserted(
Operation *op,
322 ForwardingListener::notifyOperationInserted(op, previous);
324 if (previous.
isSet())
326 auto inserted = newOps.insert(op);
328 assert(inserted.second &&
"expected newly created op");
331 void notifyOperationErased(
Operation *op)
override {
332 ForwardingListener::notifyOperationErased(op);
346 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
347 NewOpsListener newOpsListener(previousListener);
351 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
354 }
else if (getMemcpyOp() ==
"memref.copy") {
357 }
else if (getMemcpyOp() ==
"linalg.copy") {
361 llvm_unreachable(
"invalid memcpy op");
363 if (getAllocOp() ==
"memref.alloc") {
366 }
else if (getAllocOp() ==
"memref.alloca") {
370 llvm_unreachable(
"invalid alloc op");
372 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373 options.emitDealloc = getEmitDealloc();
377 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
379 for (
Operation *op : state.getPayloadOps(getTarget())) {
384 <<
"failed to bufferize operation";
385 diag.attachNote(op->
getLoc()) <<
"target payload op";
388 allocatedBuffers.push_back(buffer);
392 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
393 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
397 void transform::BufferizeToAllocationOp::getEffects(
399 if (getBufferizeDestinationOnly()) {
411 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
412 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
413 return emitOpError() <<
"unsupported memcpy op";
414 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
415 return emitOpError() <<
"unsupported alloc op";
428 #define DOWNSCALE(trans) \
430 FailureOr<LinalgOp> res = tryApply<trans>(target); \
431 if (succeeded(res)) { \
432 results.push_back(*res); \
433 return DiagnosedSilenceableFailure::success(); \
437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
451 #undef DOWNSCALE_NORMAL
452 #undef DOWNSCALE_CALL
454 return emitDefaultSilenceableFailure(target);
468 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
469 if (!decomposableOp) {
471 "payload is not a decomposable op"));
472 return emitDefaultSilenceableFailure(target);
475 FailureOr<SmallVector<Value>> maybeNewResults =
476 decomposableOp.decomposeOperation(rewriter);
477 if (failed(maybeNewResults))
478 return emitDefaultSilenceableFailure(target);
480 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
481 for (
Value val : *maybeNewResults) {
482 Operation *definition = val.getDefiningOp();
493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
504 options.allowReturnAllocsFromLoops =
true;
506 for (
Operation *target : state.getPayloadOps(getTarget())) {
510 <<
"failed to analyze op";
512 rewriter, target, state)))
514 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
525 template <
typename Range>
529 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
535 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
536 if (!tilingInterfaceOp)
537 return transformOp->
emitError(
"only TilingInterface ops are supported");
540 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
541 applyFn(tilingInterfaceOp);
542 if (failed(tiledResults))
547 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548 for (
Operation *toReplace : opsToReplace) {
549 for (
OpResult res : toReplace->getResults())
550 if (
auto replacement = tiledResults->replacements.lookup(res))
552 if (toReplace->use_empty()) {
558 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
559 assert(tiledResults->loops.size() == numLoops &&
560 "Mismatched number of loops, tile and fuse transform should have "
562 for (
unsigned int i = 0; i < numLoops; ++i)
563 loopOps[i].push_back(tiledResults->loops[i]);
566 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
567 for (
unsigned int i = 0; i < numLoops; ++i)
568 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
578 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
580 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
586 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
590 if (getApplyCleanup()) {
593 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
600 rewriter, getOperation(), state.getPayloadOps(getTarget()),
601 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602 [&](TilingInterface tilingInterfaceOp)
603 -> FailureOr<scf::SCFTileAndFuseResult> {
604 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
613 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615 if (!std::is_permutation(sequence.begin(), sequence.end(),
616 permutation.begin(), permutation.end())) {
617 return emitOpError() <<
"expects interchange to be a permutation, found "
618 << getTileInterchange();
622 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624 if (numExpectedLoops != getNumResults() - 1)
625 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
634 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
637 Value containingOp) {
640 result.
addTypes({resultType, resultType});
656 (domInfo.
dominates(containingOp, user))) {
657 dominatedUsers.insert(user);
660 if (dominatedUsers.empty())
664 auto forallOp = cast<scf::ForallOp>(containingOp);
670 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
675 newOuts.push_back(outputs[resultNumber]);
678 auto newforallOp = rewriter.
create<scf::ForallOp>(
679 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
682 newforallOp.getRegion().takeBody(forallOp.getRegion());
687 newforallOp.getBody()->addArgument(newOuts.back().getType(),
688 newOuts.back().getLoc());
689 auto bbArgs = newforallOp.getBody()->getArguments();
697 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
699 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
700 Operation *firstYieldOp = yieldingOps.front();
703 Value dst = newforallOp.getRegionIterArgs().back();
705 rewriter.
create<tensor::ParallelInsertSliceOp>(firstYieldOp->
getLoc(), src,
706 dst, offsets, sizes, strides);
710 newforallOp->getResult(result.index()));
713 newforallOp->getResults().back(),
716 return dominatedUsers.contains(user);
727 static std::tuple<SmallVector<Operation *>,
Operation *>
730 LLVM_DEBUG(
DBGS() <<
"Try to fuse a direct extract use\n");
731 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
732 if (!tileableProducer) {
734 <<
"producer is not a TileableInterface: " << *producerOp;
741 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
742 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
743 return sliceOp && containingOp->isProperAncestor(sliceOp);
747 if (it == tileableProducer->getUsers().end()) {
748 diag.attachNote(tileableProducer->getLoc())
749 <<
"could not find fusion opportunity for: " << *tileableProducer;
752 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
759 int64_t resultNumber =
760 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
761 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
766 FailureOr<TilingResult> tileAndFuseResult =
767 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
770 if (failed(tileAndFuseResult)) {
771 diag.attachNote(tileableProducer->getLoc())
772 <<
"failed to tile producer op: " << *tileableProducer;
777 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
778 LLVM_DEBUG(
DBGS() <<
"tiledProducer: " << *tiledOp <<
"\n");
783 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
784 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
785 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
786 if (failed(maybeRankReduced)) {
788 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
789 << tileAndFuseResult->tiledValues[0]
790 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
793 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
797 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
798 resultNumber, offsets, sizes);
800 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
813 LLVM_DEBUG(
DBGS() <<
"Try to fuse an extract use through block argument\n");
815 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
816 if (!tileableProducer) {
818 <<
"producer is not a TileableInterface: " << *producerOp;
823 scf::ForallOp forallOp;
824 auto itProducerUses =
825 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
826 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
830 if (!forallOp || forallOp != containingOp) {
831 diag.attachNote(tileableProducer->getLoc())
832 <<
"could not find a use by the containing op: " << *tileableProducer;
847 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
848 return sliceOp && containingOp->isProperAncestor(sliceOp);
852 if (itBBArgUsers == bbArg.
getUsers().end()) {
854 <<
"could not find fusion opportunity for bbArg: " << bbArg;
857 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
865 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
866 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
871 rewriter, tileableProducer->getLoc(), tileableProducer,
872 destinationTensors))) {
873 diag.attachNote(tileableProducer->getLoc())
874 <<
"failed to get destination tensors for: " << *tileableProducer;
879 bvm.
map(destinationTensors[resultNumber], bbArg);
880 auto tileableProducerClone =
881 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
883 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
886 FailureOr<TilingResult> tileAndFuseResult =
887 tileableProducerClone.generateResultTileValue(
888 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
889 sliceOpToTile.getMixedSizes());
890 if (failed(tileAndFuseResult)) {
891 diag.attachNote(tileableProducer->getLoc())
892 <<
"failed to tile producer op: " << *tileableProducer;
897 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
898 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
899 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
900 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
901 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
906 destinationTensors.front());
909 return tileAndFuseResult->tiledOps;
915 LLVM_DEBUG(
DBGS() <<
"Try to fuse an use by cloning\n");
920 for (
OpOperand &use : result.getUses()) {
922 uses.push_back(&use);
927 if (containingOp == use.getOwner()) {
929 <<
"producer op use by containing op cannot be fused by cloning";
937 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
946 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
947 "Parallel insert slice is not a valid clone destination");
948 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
949 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
953 fusedOp = rewriter.
clone(*producerOp);
955 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
960 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
970 auto producerOps = state.getPayloadOps(getProducerOp());
971 auto containingOps = state.getPayloadOps(getContainingOp());
972 if (!llvm::hasSingleElement(containingOps)) {
974 <<
"requires exactly one containing_op handle (got "
975 << llvm::range_size(containingOps) <<
")";
977 Operation *containingOp = *containingOps.begin();
980 if (std::empty(producerOps)) {
982 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
990 auto getNextProducer = [&]() -> FailureOr<Operation *> {
991 for (
const auto &it :
enumerate(remainingProducers)) {
994 int64_t numUsesInContainingOp =
996 return containingOp->isAncestor(op);
1001 if (numUsesInContainingOp > 0) {
1002 if (numUsesInContainingOp == 1)
1003 remainingProducers.erase(remainingProducers.begin() + it.index());
1010 while (!remainingProducers.empty()) {
1011 auto nextProducer = getNextProducer();
1012 if (failed(nextProducer)) {
1014 <<
"could not find next producer to fuse into container";
1015 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1023 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1030 auto [tiledOps, newContainingOp] =
1032 if (!tiledOps.empty()) {
1033 LLVM_DEBUG(
DBGS() <<
"\nFused a direct extract use\n" << *containingOp);
1034 fusedOps.append(tiledOps);
1035 if (newContainingOp) {
1043 LogicalResult replacementStatus =
1046 (void)replacementStatus;
1047 assert(succeeded(replacementStatus) &&
1048 "unable to update transform state mapping");
1049 rewriter.
eraseOp(containingOp);
1050 containingOp = newContainingOp;
1057 rewriter,
diag, producerOp, containingOp);
1058 if (!tiledContainingOpOperand.empty()) {
1059 LLVM_DEBUG(
DBGS() <<
"\nFused an extract use through block argument\n"
1061 fusedOps.append(tiledContainingOpOperand);
1068 LLVM_DEBUG(
DBGS() <<
"\nFused an use by cloning\n" << *containingOp);
1069 fusedOps.push_back(cloned);
1075 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1076 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1080 void transform::FuseIntoContainingOp::getEffects(
1098 if (isa<GenericOp>(target)) {
1104 if (succeeded(
generic)) {
1105 results.
push_back(generic->getOperation());
1108 return emitDefaultSilenceableFailure(target);
1121 if (!isa<GenericOp>(target)) {
1126 FailureOr<LinalgOp> named =
1128 if (succeeded(named)) {
1129 results.
push_back(named->getOperation());
1132 return emitDefaultSilenceableFailure(target);
1146 if (interchangeVector.empty()) {
1151 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1152 if (interchangeVector.size() != numLoops) {
1153 return emitSilenceableError()
1154 << getIteratorInterchangeAttrName() <<
" has length ("
1155 << interchangeVector.size()
1156 <<
") different from the number of loops in the target operation ("
1169 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1170 if (!std::is_permutation(sequence.begin(), sequence.end(),
1171 permutation.begin(), permutation.end())) {
1172 return emitOpError()
1173 <<
"expects iterator_interchange to be a permutation, found "
1174 << getIteratorInterchange();
1188 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1189 FailureOr<LowerPackResult> res =
1190 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1193 <<
"cannot lower to pad + expand + transpose";
1196 transformResults.
push_back(res->expandShapeOp);
1197 transformResults.
push_back(res->transposeOp);
1210 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1211 FailureOr<LowerUnPackOpResult> res =
1212 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1215 emitSilenceableError()
1216 <<
"cannot lower to transpose + collapse + extract";
1217 diag.attachNote(target->getLoc()) <<
"target payload op";
1220 transformResults.
push_back(res->emptyOp);
1221 transformResults.
push_back(res->transposeOp);
1222 transformResults.
push_back(res->collapseShapeOp);
1223 transformResults.
push_back(res->extractSliceOp);
1253 if (getOps().has_value())
1254 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1255 getOps()->getAsValueRange<StringAttr>().end());
1257 auto payloadOps = state.getPayloadOps(getTarget());
1258 if (!llvm::hasSingleElement(payloadOps)) {
1263 bool incorrectNumOperandTypes =
false;
1270 if (getInterface().has_value()) {
1271 auto iface = getInterface().value();
1272 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1275 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1276 !isa<TilingInterface>(op))
1278 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1279 !isa<LoopLikeOpInterface>(op))
1284 if (getOpAttrs().has_value()) {
1285 DictionaryAttr opAttrs = getOpAttrs().value();
1287 if (attr.getName() == getInterfaceAttrName() ||
1288 attr.getName() == getOpsAttrName())
1290 if (!op->
hasAttr(attr.getName()))
1292 if (op->
getAttr(attr.getName()) != attr.getValue())
1297 if (getFilterResultType().has_value()) {
1298 Type t = getFilterResultType().value();
1303 if (getFilterOperandTypes().has_value()) {
1304 mlir::ArrayAttr types = getFilterOperandTypes().value();
1307 if (types.size() == 1) {
1310 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1311 Type t = cast<::mlir::Type>(typeattr.getValue());
1313 [&](
Type operandType) { return operandType == t; }))
1318 if (types.size() != operandTypes.size()) {
1319 incorrectNumOperandTypes =
true;
1323 for (
auto [attr, operandType] :
1324 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1325 auto typeattr = cast<mlir::TypeAttr>(attr);
1326 Type type = cast<::mlir::Type>(typeattr.getValue());
1328 if (type != operandType)
1339 (*payloadOps.begin())->
walk(matchFun);
1340 if (incorrectNumOperandTypes)
1342 "type, then it must contain as much types as "
1343 "the number of operands in the target ops");
1344 results.
set(cast<OpResult>(getResult()), res);
1359 Type &targetType,
Type &lowSizeType,
1361 Type &splitPointType) {
1362 FunctionType funcType;
1364 if (failed(parser.
parseType<FunctionType>(funcType)))
1367 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1368 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1369 "argument and one result";
1371 targetType = funcType.getInput(0);
1372 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1380 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1381 if (target.hasDynamicShape()) {
1382 auto diag = emitSilenceableError()
1383 <<
"cannot compute parametric tile sizes for dynamically "
1384 "shaped payload op";
1385 diag.attachNote(target->getLoc()) <<
"payload op";
1390 target, getDimension(), getTargetSize(), getDivisor());
1392 return emitSilenceableError()
1393 <<
"failed to compute multi-size tiling sizes";
1396 Builder builder(target.getContext());
1397 results.
assign(llvm::map_range(
1399 spec->lowTileSize * spec->lowTripCount}),
1400 [&builder,
this](int64_t value) {
1412 builder, target, getDimension(), targetSize, divisor);
1414 return emitSilenceableError() <<
"could not generate tile size computation";
1421 {spec->lowTileSize, spec->lowTripCount});
1422 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1423 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1424 assert(lowTileSize && highTileSize && splitPoint &&
1425 "tile sizes are not produced by operations");
1433 void transform::MultiTileSizesOp::getEffects(
1437 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1446 return emitOpError() <<
"expects all results type to be the same";
1466 builder.
getContext(), GenericOp::getOperationName());
1467 build(builder, result,
1476 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1483 auto targetOps = state.getPayloadOps(getTarget());
1485 if (std::empty(targetOps)) {
1486 transformResults.
set(cast<OpResult>(getPackedOp()),
1491 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1492 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1493 return emitSilenceableError()
1494 <<
"requires target to map to exactly 1 LinalgOp (got "
1495 << llvm::range_size(targetOps) <<
")";
1498 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1499 return emitSilenceableError()
1500 <<
"requires number of packed sizes match the number of loops ("
1501 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1508 state, *
this, packedSizes, getMixedPackedSizes());
1511 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1512 if (failed(maybeResult))
1515 transformResults.
set(cast<OpResult>(getPackedOp()),
1516 {maybeResult->packedLinalgOp.getOperation()});
1520 void transform::PackOp::getEffects(
1534 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1535 <<
" is not a valid permutation";
1538 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1539 for (
auto [s, nmo] :
1540 llvm::zip_equal(getMixedMatmulPackedSizes(),
1541 getMatmulPaddedSizesNextMultipleOf())) {
1544 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1545 return emitOpError() <<
"at most one of the packed_size and the "
1546 "padded_sizes_next_multiple_of can be nonzero "
1547 "for the matmul strategy";
1559 for (
Operation *op : state.getPayloadOps(getTarget())) {
1560 auto linalgOp = dyn_cast<LinalgOp>(op);
1571 getMixedMatmulPackedSizes(),
1573 getMatmulPaddedSizesNextMultipleOf(),
1574 getMatmulInnerDimsOrder());
1575 if (succeeded(packResult)) {
1576 results.push_back(packResult->packedLinalgOp);
1579 results.push_back(linalgOp);
1581 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1587 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1591 void transform::PackGreedilyOp::getEffects(
1605 return emitOpError() << getInnerPermAttrName()
1606 <<
" is not a valid permutation";
1609 return emitOpError() << getOuterPermAttrName()
1610 <<
" is not a valid permutation";
1612 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1613 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1614 <<
" or " << getOuterPermAttrName()
1615 <<
" must be specified";
1621 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1631 template <
typename RelayoutOpTy>
1634 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1636 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1637 "applies to only pack or unpack operations");
1638 if (!op || permutation.empty())
1640 size_t innerRank = op.getInnerDimsPos().size();
1641 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1645 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1646 return permutation.size() == op.getSourceRank() &&
1649 return permutation.size() == op.getDestRank() &&
1657 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1658 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1660 if (std::empty(packOrUnpackOps)) {
1661 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1662 transformResults.
set(cast<OpResult>(getPackOp()), {});
1663 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1669 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1670 !llvm::hasSingleElement(linalgOps)) {
1671 return emitSilenceableError()
1672 <<
"requires target to map to exactly 1 "
1673 "packing op and 1 packed op ("
1674 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1675 << llvm::range_size(linalgOps) <<
")";
1679 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1680 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1681 if ((!packOp && !unPackOp)) {
1682 return emitSilenceableError() <<
"requires target to map to a "
1683 "linalg.pack or linalg.unpack";
1685 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1686 if (!linalgOpTarget)
1687 return emitSilenceableError() <<
"requires a LinalgOp target";
1691 if (packOp && packOp.getResult().hasOneUse())
1692 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1694 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1695 if (linalgOp != linalgOpTarget) {
1697 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1698 : StringLiteral{
"not produced by the LinalgOp target"};
1699 return emitSilenceableError() << errorMsg;
1705 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1706 OpOperand *packUse = linalgOp.getDpsInitOperand(
1707 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1709 if (!packOp || !packOp.getResult().hasOneUse())
1710 return emitSilenceableError() <<
"could not find matching pack op";
1714 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1716 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1717 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1718 ? StringLiteral{
"invalid outer_perm"}
1719 : StringLiteral{
"invalid inner_perm"};
1723 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1724 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1730 assert(packOp && linalgOp &&
"unexpected null op");
1734 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1736 assert(succeeded(res) &&
"unexpected packTranspose failure");
1739 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1740 transformResults.
set(cast<OpResult>(getPackedOp()),
1741 {res->transposedLinalgOp});
1743 transformResults.
set(cast<OpResult>(getUnPackOp()),
1744 {res->transposedUnPackOp});
1746 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1761 StringRef copyBackOp) {
1771 (padToMultipleOf.empty()
1773 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1774 b.getI64ArrayAttr(nofoldFlags),
1775 b.getArrayAttr(transposePaddings),
1776 b.getStringAttr(copyBackOp));
1784 StringRef copyBackOp) {
1789 staticPadToMultipleOf);
1796 dynamicPadToMultipleOf,
1797 staticPadToMultipleOf,
1803 void PadOp::getEffects(
1813 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1820 auto transformOp = cast<TransformOpInterface>(getOperation());
1823 for (
Operation *target : state.getPayloadOps(getTarget())) {
1824 auto linalgTarget = dyn_cast<LinalgOp>(target);
1825 if (!linalgTarget) {
1826 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1827 diag.attachNote(target->
getLoc()) <<
"target op";
1833 for (int64_t packPadding :
1834 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1835 nofoldFlags.push_back(
static_cast<bool>(packPadding));
1839 for (
auto const &it :
1840 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1841 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1843 emitOpError(
"expects padding values to be typed attributes");
1848 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
1852 if (!parsedAttr || parsedAttr.getType() != elementType) {
1853 auto diag = this->emitOpError(
"expects a padding that parses to ")
1854 << elementType <<
", got " << std::get<0>(it);
1855 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1858 paddingValues.push_back(parsedAttr);
1862 if (attr.getType() != elementType) {
1863 auto diag = this->emitOpError(
"expects a padding value of type ")
1864 << elementType <<
", got " << attr;
1865 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1868 paddingValues.push_back(attr);
1873 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1874 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1875 cast<ArrayAttr>(transposeVector)));
1880 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1884 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1887 if (padToMultipleOf.empty())
1891 options.padToMultipleOf = padToMultipleOf;
1892 options.paddingValues = paddingValues;
1893 options.nofoldFlags = nofoldFlags;
1894 if (getCopyBackOp() ==
1895 bufferization::MaterializeInDestinationOp::getOperationName()) {
1898 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1900 }
else if (getCopyBackOp() == kCopyOpNone) {
1903 llvm_unreachable(
"unsupported copy_back op");
1909 replacements, newPadOps))) {
1910 auto diag = emitSilenceableError() <<
"failed to pad op";
1911 diag.attachNote(target->
getLoc()) <<
"target op";
1920 rewriter.
replaceOp(linalgTarget, replacements);
1921 paddedOps.push_back(paddedOp);
1922 padOps.append(newPadOps.begin(), newPadOps.end());
1924 for (
Value v : replacements) {
1925 Operation *copyBackOp = v.getDefiningOp();
1926 if (!llvm::is_contained(copyBackOps, copyBackOp))
1927 copyBackOps.push_back(copyBackOp);
1932 results.
set(cast<OpResult>(getPadded()), paddedOps);
1933 results.
set(cast<OpResult>(getPad()), padOps);
1934 results.
set(cast<OpResult>(getCopy()), copyBackOps);
1940 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
1941 if (any_of(nofoldFlags, [](int64_t packPadding) {
1942 return packPadding != 0 && packPadding != 1;
1944 return emitOpError()
1945 <<
"expects nofold_flags to contain booleans (0/1), found "
1946 << getNofoldFlags();
1950 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1951 if (any_of(paddingDimensions,
1952 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
1953 return emitOpError() <<
"expects padding_dimensions to contain positive "
1955 << getPaddingDimensions();
1957 if (!getMixedPadToMultipleOf().empty()) {
1958 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
1959 return emitOpError() <<
"expects as many multiples as padding_dimensions";
1962 ArrayAttr transposes = getTransposePaddings();
1965 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
1966 if (!std::is_permutation(sequence.begin(), sequence.end(),
1968 return emitOpError()
1969 <<
"expects transpose_paddings to be a permutation, found "
1973 if (getCopyBackOp() !=
1974 bufferization::MaterializeInDestinationOp::getOperationName() &&
1975 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1976 getCopyBackOp() != kCopyOpNone)
1977 return emitOpError() <<
"invalid copy_back_op";
1989 auto targetOps = state.getPayloadOps(getTarget());
1990 auto loopOps = state.getPayloadOps(getLoop());
1991 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1993 <<
"requires exactly one target and one loop handle (got "
1994 << llvm::range_size(targetOps) <<
" and "
1995 << llvm::range_size(loopOps) <<
")";
1998 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1999 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2000 if (!padOp || !loopOp)
2003 FailureOr<linalg::detail::PackingResult> result =
2009 if (result->clonedLoopIvs.empty()) {
2010 transformResults.
set(cast<OpResult>(getPackingLoop()),
2011 {result->hoistedPadOp.getOperation()});
2014 auto outerPackedLoop =
2016 transformResults.
set(cast<OpResult>(getPackingLoop()),
2017 {outerPackedLoop.getOperation()});
2023 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2024 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2026 return emitOpError() <<
"expects transpose to be a permutation, found "
2032 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2042 tensor::PadOp target,
2045 tensor::PadOp hoistedPadOp;
2047 FailureOr<Value> result =
2049 hoistedPadOp, transposeOps);
2050 if (succeeded(result)) {
2060 return emitDefaultSilenceableFailure(target);
2065 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2066 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2068 return emitOpError() <<
"expects transpose to be a permutation, found "
2084 if (!getOperandsToPromote().empty())
2086 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2087 if (getUseFullTilesByDefault())
2089 getUseFullTilesByDefault());
2091 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2092 if (!getUseFullTileBuffers().empty())
2094 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2095 if (getAlignment().has_value())
2096 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2097 if (getMemorySpace().has_value())
2098 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2100 if (getMapping().has_value()) {
2102 auto mapping = *getMapping();
2103 if (mapping.size() > 1)
2104 return emitDefaultDefiniteFailure(target);
2106 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2108 if (addressSpace.getAddressSpace() ==
2109 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2116 }
else if (addressSpace.getAddressSpace() ==
2117 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2125 return emitDefaultDefiniteFailure(target);
2130 return emitDefaultDefiniteFailure(target);
2133 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2135 return emitDefaultDefiniteFailure(target);
2148 auto payload = state.getPayloadOps(getTarget());
2152 if (target->getNumOperands() > 0)
2155 target->getNumRegions() > 0)
2157 <<
"expected target that is isolated from above";
2161 Operation *pattern = &getBodyRegion().front().front();
2164 if (getOperation()->isAncestor(target))
2169 replacements.push_back(replacement);
2171 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2175 void transform::ReplaceOp::getEffects(
2183 if (!getBodyRegion().hasOneBlock())
2184 return emitOpError() <<
"expected one block";
2185 if (std::distance(getBodyRegion().front().begin(),
2186 getBodyRegion().front().end()) != 1)
2187 return emitOpError() <<
"expected one operation in block";
2188 Operation *replacement = &getBodyRegion().front().front();
2191 <<
"expected replacement without operands";
2195 <<
"expect op that is isolated from above";
2213 target.createFlatListOfOperandDims(b, loc);
2214 AffineMap map = target.getShapesToLoopsMap();
2230 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2231 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2232 if (failed(maybeTilingResult))
2233 return emitDefaultDefiniteFailure(target);
2235 if (target->getNumResults())
2236 rewriter.
replaceOp(target, maybeTilingResult->mergeResult.replacements);
2240 results.
reserve(maybeTilingResult->tiledOps.size());
2241 for (
Operation *tiled : maybeTilingResult->tiledOps)
2255 for (
Operation *target : state.getPayloadOps(getTarget())) {
2256 auto tilingOp = dyn_cast<TilingInterface>(*target);
2259 emitSilenceableError()
2260 <<
"expected the payload to implement TilingInterface";
2261 diag.attachNote(target->getLoc()) <<
"payload op";
2265 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2267 if (failed(generatedLoops))
2268 return emitDefaultDefiniteFailure(target);
2269 for (scf::ForOp &loop : *generatedLoops) {
2270 loops.push_back(loop.getOperation());
2274 results.
set(cast<OpResult>(getResult()), loops);
2283 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2289 FailureOr<Operation *> maybeResult =
2291 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2292 [&rewriter](
auto op) {
2295 if (failed(maybeResult))
2296 return emitDefaultSilenceableFailure(target);
2310 llvm::to_vector(state.getPayloadOps(getTarget()));
2312 bool isMultiwaySplit = getMultiway();
2314 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2316 <<
"requires exactly one target when "
2317 "multiway split is enabled (got "
2318 << llvm::range_size(payload) <<
")";
2323 if (!isMultiwaySplit)
2324 chunkSizes.reserve(payload.size());
2326 if (getDynamicChunkSizes()) {
2328 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2329 chunkSizes = llvm::to_vector(llvm::map_range(
2330 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2333 diag = emitSilenceableError()
2334 <<
"expected dynamic split point handle to point to a "
2335 "single-result index-typed op";
2336 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2341 chunkSizes = llvm::to_vector(
2342 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2345 if (
diag.isSilenceableFailure())
2350 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2352 <<
"expected the dynamic split point handle to point to as "
2354 << chunkSizes.size() <<
") as the target handle ("
2355 << payload.size() <<
")";
2358 chunkSizes.resize(payload.size(),
2362 auto checkStructuredOpAndDimensions =
2365 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2366 diag.attachNote(loc) <<
"target op";
2370 if (getDimension() >= linalgOp.getNumLoops()) {
2371 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2372 <<
" does not exist in target op";
2373 diag.attachNote(loc) <<
"target op";
2379 auto checkFailureInSplitting =
2383 diag.attachNote(loc) <<
"target op";
2390 if (isMultiwaySplit) {
2393 TilingInterface head, tail;
2396 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2400 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2401 if (
diag.isSilenceableFailure())
2407 target = tail.getOperation();
2412 linalgOp = cast<LinalgOp>(target);
2417 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2418 getDimension(), chunkSize);
2422 checkFailureInSplitting(!head && !tail, loc);
2423 if (
diag.isDefiniteFailure())
2426 opList.push_back(head.getOperation());
2431 opList.push_back(tail.getOperation());
2437 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2440 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2442 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2444 if (
diag.isSilenceableFailure())
2448 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2449 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2450 getDimension(), std::get<1>(pair));
2454 checkFailureInSplitting(!first.back() && !second.back(), loc);
2459 if (!second.back()) {
2460 noSecondPart = target;
2465 if (second.size() != first.size() && !second.empty()) {
2466 auto diag = emitSilenceableError()
2467 <<
"splitting does not produce the second part for a subset "
2470 <<
"expected splitting to produce the second part of all "
2471 "or none of the targets";
2473 <<
"first target with no second part";
2477 opList.append(first);
2479 opList.append(second);
2481 results.
set(cast<OpResult>(getSplitList()), opList);
2485 void SplitOp::getEffects(
2488 if (getDynamicChunkSizes())
2496 IntegerAttr staticChunkSizes;
2502 if (!dynamicPointParseResult.
has_value()) {
2503 int64_t staticChunkSizesValue;
2504 if (failed(parser.
parseInteger(staticChunkSizesValue)))
2517 if (dynamicPointParseResult.
has_value()) {
2518 Type ChunkSizesType;
2519 if (failed(*dynamicPointParseResult) || parser.
parseComma() ||
2531 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2538 printer <<
" " << getTarget() <<
" after ";
2539 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2540 if (staticChunkSize != ShapedType::kDynamic)
2541 printer << staticChunkSize;
2543 printer << getDynamicChunkSizes();
2546 {getStaticChunkSizesAttrName()});
2547 printer <<
" : " << getTarget().getType();
2548 if (staticChunkSize == ShapedType::kDynamic)
2549 printer <<
", " << getDynamicChunkSizes().getType();
2553 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2554 (getDynamicChunkSizes() ==
nullptr)) {
2555 return emitOpError() <<
"expects either a dynamic or a static split "
2556 "point to be provided";
2565 void transform::SplitReductionOp::build(
2567 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2568 bool useScalingAlgorithm,
bool useAlloc) {
2574 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2576 if (innerParallel) {
2577 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2580 if (useScalingAlgorithm) {
2582 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2590 result.
addTypes({resultType, resultType, resultType, resultType});
2599 unsigned(getInsertSplitDimension()),
2600 bool(getInnerParallel())};
2603 FailureOr<SplitReductionResult> splitResult =
2604 (getUseScalingAlgorithm())
2607 if (failed(splitResult))
2608 return emitDefaultDefiniteFailure(target);
2610 results.
push_back(splitResult->initOrAlloc);
2612 results.
push_back(splitResult->splitLinalgOp);
2613 results.
push_back(splitResult->resultCombiningLinalgOp);
2621 void transform::TileReductionUsingForOp::build(
2632 build(builder, result,
2635 staticTileSizesAttr);
2644 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2645 if (!partialReductionOp) {
2648 "Operation should implement PartialReductionOpInterface");
2651 rewriter, partialReductionOp,
2655 return emitDefaultSilenceableFailure(target);
2656 rewriter.
replaceOp(target, result->mergeResult.replacements);
2657 for (
Value initValue : result->initialValues)
2659 for (
auto parallelTiledOp : result->tiledOps)
2661 for (
auto mergeOp : result->mergeResult.mergeOps)
2663 results.
push_back(result->loops.front());
2671 void transform::TileReductionUsingForallOp::build(
2674 ArrayAttr mapping) {
2684 build(builder, result,
2687 staticNumThreadsAttr,
2688 staticTileSizesAttr,
2701 FailureOr<linalg::ForallReductionTilingResult> result =
2703 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2704 numThreads, tileSizes, getMapping());
2706 if (failed(result)) {
2707 auto diag = emitSilenceableError() <<
"could not tile reduction";
2708 diag.attachNote(target.getLoc()) <<
"target operation";
2711 for (
Value initValue : result->initialValues)
2713 for (
auto parallelTiledOp : result->parallelTiledOps)
2715 for (
auto mergeOp : result->mergeOps)
2731 llvm::to_vector(state.getPayloadOps(getTarget()));
2733 if (!llvm::hasSingleElement(targetOps)) {
2735 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
2740 auto linalgOp = dyn_cast<LinalgOp>(target);
2741 auto tileableOp = dyn_cast<TilingInterface>(target);
2746 OpBuilder builder(linalgOp.getContext());
2748 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
2749 if (linalgOp.hasDynamicShape()) {
2750 auto diag = emitSilenceableError()
2751 <<
"cannot compute parametric tile sizes for dynamically "
2752 "shaped payload op";
2753 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
2757 FailureOr<StaticContinuousTileSizeSpecification> spec =
2761 return emitSilenceableError()
2762 <<
"failed to compute multi-size tiling sizes";
2767 for (
auto &&[tileSize, tripCount] :
2768 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2769 chunkSizes.push_back(tileSize * tripCount);
2772 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
2776 transformResults.
setParams(cast<OpResult>(getTileSizes()),
2777 getI64AttrsFromI64(spec->tileSizes));
2778 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
2779 getI64AttrsFromI64(chunkSizes));
2787 unsigned dimension = getDimension();
2790 builder, tileableOp, dimension, targetSize,
true);
2792 return emitSilenceableError() <<
"could not generate tile size computation";
2804 for (
auto &&[tileSize, tripCount] :
2805 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2806 splitPoint = apply(s0 * s1, {tileSize, tripCount});
2807 chunkSizes.push_back(splitPoint);
2811 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
2816 transformResults.
set(cast<OpResult>(getTileSizes()),
2817 getDefiningOps(spec->tileSizes));
2818 transformResults.
set(cast<OpResult>(getChunkSizes()),
2819 getDefiningOps(chunkSizes));
2827 return emitOpError() <<
"expects all results type to be the same";
2833 void transform::ContinuousTileSizesOp::getEffects(
2835 if (isa<TransformParamTypeInterface>(getTileSizes().
getType()))
2851 Type &tileSizesType,
2852 Type &chunkSizesType) {
2853 FunctionType funcType;
2855 if (failed(parser.
parseType<FunctionType>(funcType)))
2858 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2859 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
2860 "argument and one result";
2862 targetType = funcType.getInput(0);
2863 tileSizesType = chunkSizesType = funcType.getResult(0);
2872 void transform::TileUsingForOp::build(
2877 return build(builder, result, loopTypes,
2881 interchange, scalableSizes);
2884 void transform::TileUsingForOp::build(
2888 build(builder, result, target,
2890 interchange, scalableSizes);
2893 void transform::TileUsingForOp::build(
2900 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2904 void transform::TileUsingForOp::build(
2916 unsigned numExpectedLoops =
2917 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2919 resultTypes.reserve(numExpectedLoops);
2920 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2921 "expected one loop type or as many as loops");
2922 if (loopTypes.size() == 1)
2923 resultTypes.append(numExpectedLoops, loopTypes[0]);
2925 llvm::append_range(resultTypes, loopTypes);
2927 if (scalableSizes.has_value())
2928 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2929 build(builder, result, target.
getType(),
2933 staticTileSizesAttr,
2935 expandedScalableSizes);
2940 return emitOpError(
"expected same number of sizes (")
2942 << getScalableSizes().size() <<
")";
2944 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
2945 if (getLoops().size() != numExpectedLoops)
2946 return emitOpError(
"expected number of loops to tile (")
2947 << numExpectedLoops <<
") to match number of `loops` results ("
2948 << getLoops().size() <<
")";
2959 llvm::to_vector(state.getPayloadOps(getTarget()));
2965 if (isa<ParamType>(transformValue.getType())) {
2966 dynamicSizeProducers.push_back({});
2968 paramSizes.push_back(
2969 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
2970 return cast<IntegerAttr>(attr).getValue().getSExtValue();
2973 if (paramSizes.back().size() != targets.size()) {
2975 emitSilenceableError()
2976 <<
"expected as many parameter values ("
2977 << dynamicSizeProducers.back().size() <<
") as target ops ("
2978 << targets.size() <<
")";
2979 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
2985 paramSizes.push_back({});
2986 dynamicSizeProducers.push_back(
2987 llvm::to_vector(state.getPayloadOps(transformValue)));
2989 if (dynamicSizeProducers.back().size() != targets.size()) {
2991 emitSilenceableError()
2992 <<
"expected as many dynamic size-producing operations ("
2993 << dynamicSizeProducers.back().size() <<
") as target ops ("
2994 << targets.size() <<
")";
2995 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
2999 for (
Operation *op : dynamicSizeProducers.back()) {
3006 emitSilenceableError() <<
"expected sizes to be produced by ops "
3007 "with a single index-type result";
3008 diag.attachNote(op->
getLoc()) <<
"size producer op";
3009 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3016 loops.resize(getLoops().size());
3017 auto scalableSizes = getScalableSizes();
3019 auto tilingInterface = dyn_cast<TilingInterface>(op);
3020 if (!tilingInterface) {
3022 emitSilenceableError()
3023 <<
"only ops implementing TilingInterface are supported";
3024 diag.attachNote(op->
getLoc()) <<
"target op";
3027 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3029 emitSilenceableError()
3030 <<
"too many tiles provided, expected at most "
3031 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3032 << tileSizes.size();
3033 diag.attachNote(op->
getLoc()) <<
"target op";
3038 if (tileSizes.empty()) {
3047 sizes.reserve(tileSizes.size());
3048 unsigned dynamicIdx = 0;
3051 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3052 if (scalableSizes[ofrIdx]) {
3053 auto val = b.
create<arith::ConstantIndexOp>(
3054 getLoc(), cast<IntegerAttr>(attr).getInt());
3060 sizes.push_back(attr);
3067 assert((dynamicSizes.empty() ^ params.empty()) &&
3068 "expected either dynamic sizes or parameters");
3069 if (!params.empty()) {
3072 sizes.push_back(dynamicSizes[index]->getResult(0));
3080 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3081 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3082 if (failed(maybeTilingResult))
3085 rewriter.
replaceOp(op, maybeTilingResult->mergeResult.replacements);
3087 tiled.append(maybeTilingResult->tiledOps);
3089 loops[en2.index()].push_back(en2.value());
3092 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3094 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3103 results.reserve(tileSizes.size());
3104 unsigned dynamicPos = 0;
3106 for (int64_t size : tileSizes) {
3107 if (size == ShapedType::kDynamic) {
3108 results.push_back(dynamic[dynamicPos++]);
3116 void transform::TileUsingForOp::getEffects(
3128 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3132 ArrayAttr mapping) {
3133 return build(builder, result,
3141 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3145 ArrayAttr mapping) {
3155 build(builder, result,
3156 TypeRange{operationType, operationType},
3163 staticTileSizesAttr,
3167 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3171 ArrayAttr mapping) {
3172 return build(builder, result, target,
3177 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3181 ArrayAttr mapping) {
3192 build(builder, result,
3193 TypeRange{operationType, operationType},
3199 staticNumThreadsAttr,
3212 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3214 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3216 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3217 normalizedUbs.push_back(normalizedUb);
3219 return normalizedUbs;
3235 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3238 denormalizedIvs.push_back(
3241 return denormalizedIvs;
3252 scf::ForallOp loop) {
3272 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
3273 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3276 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3278 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3283 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3284 normalizedForallOp.getRegionIterArgs().end());
3285 Block *origLoopBlock = loop.getBody();
3286 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3288 rewriter.
replaceOp(loop, normalizedForallOp);
3289 return normalizedForallOp;
3294 TransformOpInterface transformOp,
Operation *target,
3299 auto tileableOp = dyn_cast<TilingInterface>(target);
3302 transformOp.emitSilenceableError()
3303 <<
"only TilingInterface ops are supported";
3304 diag.attachNote(target->
getLoc()) <<
"target op";
3310 if (!mixedNumThreads.empty()) {
3311 options.setNumThreads(mixedNumThreads);
3313 options.setTileSizes(mixedTileSizes);
3316 options.setMapping(mapping.value().getValue());
3318 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3321 if (failed(maybeTilingResult))
3322 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3324 rewriter.
replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3326 tilingResult = *maybeTilingResult;
3328 if (mixedNumThreads.empty()) {
3329 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3332 scf::ForallOp normalizedForallOp =
3334 tilingResult.
loops.front() = normalizedForallOp;
3344 auto transformOp = cast<TransformOpInterface>(getOperation());
3353 getPackedNumThreads()
3355 state, transformOp, mixedNumThreads, getPackedNumThreads())
3357 state, transformOp, mixedNumThreads, getMixedNumThreads());
3361 status = getPackedTileSizes()
3363 state, transformOp, mixedTileSizes, getPackedTileSizes())
3365 state, transformOp, mixedTileSizes, getMixedTileSizes());
3369 for (
Operation *target : state.getPayloadOps(getTarget())) {
3372 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3373 getMapping(), tilingResult);
3374 if (!
diag.succeeded())
3376 tileOps.push_back(tilingResult.
loops.front());
3377 tiledOps.append(tilingResult.
tiledOps);
3380 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3381 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3386 void transform::TileUsingForallOp::getEffects(
3399 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3408 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3409 static_cast<int>(getPackedNumThreads() !=
Value());
3410 if (numThreadsSpec > 1)
3412 "num_threads and packed_num_threads are mutually exclusive");
3413 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3414 static_cast<int>(getPackedTileSizes() !=
Value());
3415 if (tileSizesSpec > 1)
3417 "tile_sizes and packed_tile_sizes are mutually exclusive");
3418 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3419 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3420 "must be specified");
3428 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3430 bool vectorizePadding,
bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3432 if (vectorizePadding) {
3434 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3438 if (vectorizeExtract) {
3440 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3444 if (flatten1DDepthwiseConv) {
3446 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3457 explicit VectorizationPattern(
MLIRContext *context,
3458 bool vectorizeExtract =
false,
3459 bool flattenConv =
false)
3461 vectorizeNDExtract(vectorizeExtract),
3462 flatten1DDepthwiseConv(flattenConv) {}
3463 LogicalResult matchAndRewrite(
Operation *op,
3467 "Unsupported Op, cannot vectorize");
3469 {}, vectorizeNDExtract,
3470 flatten1DDepthwiseConv);
3476 bool vectorizeNDExtract =
false;
3480 bool flatten1DDepthwiseConv =
false;
3485 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3490 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3491 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3497 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3498 getFlatten_1dDepthwiseConv());
3500 if (!getDisableTransferPermutationMapLoweringPatterns())
3503 if (!getDisableMultiReductionToContractPatterns())
3511 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
3512 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
3517 if (getVectorizePadding()) {
3527 config.listener = &listener;
3529 return emitDefaultDefiniteFailure(target);
3543 auto targets = state.getPayloadOps(getTarget());
3544 if (std::empty(targets))
3546 auto transformOp = cast<TransformOpInterface>(getOperation());
3549 state, transformOp, getMixedVectorSizes(), vectorSizes);
3557 <<
"Unsupported Op, cannot vectorize";
3562 getVectorizeNdExtract().value_or(
false)))) {
3564 <<
"Attempted to vectorize, but failed";
3571 void transform::VectorizeOp::getEffects(
3580 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3584 if (getStaticVectorSizes().size() != getScalableSizes().size())
3585 return emitOpError(
"expected same number of vector sizes (")
3586 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3587 << getScalableSizes().size() <<
")";
3596 transform::HoistRedundantVectorTransfersOp::applyToOne(
3613 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3632 auto maybeTransformed =
3635 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3638 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3641 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3644 .Case([&](linalg::Conv2DNchwFchwOp op) {
3650 if (failed(maybeTransformed))
3651 return emitDefaultSilenceableFailure(target);
3653 results.
push_back(maybeTransformed->first);
3655 results.
push_back(maybeTransformed->second);
3670 <<
"only elementwise flattening is supported";
3673 if (target.getNumLoops() <= 1) {
3680 std::iota(reassociation.begin(), reassociation.end(), 0);
3681 auto maybeFlattened =
3683 if (failed(maybeFlattened))
3685 <<
"attempted to flatten, but failed";
3686 results.
push_back(maybeFlattened->collapsedOp);
3687 rewriter.
replaceOp(target, maybeFlattened->results);
3700 auto maybeTransformed =
3702 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3705 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3711 if (failed(maybeTransformed))
3712 return emitDefaultSilenceableFailure(target);
3727 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3728 auto maybeTransformed =
3730 .Case([&](linalg::MatmulOp op) {
3733 .Case([&](linalg::BatchMatmulOp op) {
3736 .Default([&](
Operation *op) {
return failure(); });
3737 if (failed(maybeTransformed))
3747 template <
typename OpTy>
3751 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3752 tensor::ParallelInsertSliceOp>() &&
3755 if (
auto copySource =
3756 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3763 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3765 target->template getParentOfType<scf::InParallelOp>());
3768 Value extracted = rewriter.
create<tensor::ExtractSliceOp>(
3769 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3770 target.getMixedSizes(), target.getMixedStrides());
3771 Value copied = rewriter
3772 .
create<linalg::CopyOp>(target.getLoc(),
3773 target.getSource(), extracted)
3778 target, copied, target.getDest(), target.getMixedOffsets(),
3779 target.getMixedSizes(), target.getMixedStrides());
3781 results.
push_back(copied.getDefiningOp());
3791 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3792 return doit(rewriter, target, results, state);
3793 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3794 return doit(rewriter, target, results, state);
3797 emitSilenceableError()
3798 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3799 diag.attachNote(targetOp->
getLoc()) <<
"target op";
3812 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3814 emitSilenceableError()
3815 <<
"only linalg.copy and tensor.pad target ops are supported";
3816 diag.attachNote(target->
getLoc()) <<
"target op";
3819 assert(target->
getNumResults() == 1 &&
"expected single result");
3821 if (!resultShapedType.hasStaticShape()) {
3823 emitSilenceableError()
3824 <<
"only statically sized ops of rank <= 3 are supported";
3825 diag.attachNote(target->
getLoc()) <<
"target op";
3830 int64_t desiredBitAlignment = getDesiredBitAlignment();
3831 int64_t eltBitwidth =
3832 resultShapedType.getElementType().getIntOrFloatBitWidth();
3833 if (desiredBitAlignment % eltBitwidth != 0) {
3834 desiredBitAlignment = eltBitwidth;
3839 getTotalNumThreads(),
3840 desiredBitAlignment,
3841 resultShapedType.getShape(),
3844 resultShapedType.getElementType().getIntOrFloatBitWidth());
3845 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3847 emitSilenceableError()
3848 <<
"too few threads to map copy op to threads on the most minor "
3849 "dimension, given alignment and vector size constraints, try "
3850 "smaller tile size of mapping to more threads";
3851 diag.attachNote(target->
getLoc()) <<
"target op";
3867 if (!
diag.succeeded())
3871 for (
auto op : tilingResult.
tiledOps)
3885 FailureOr<Operation *> maybeTransformed = failure();
3887 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3892 .Default([&](
Operation *op) {
return false; });
3895 return emitSilenceableError()
3896 <<
"this operation is not supported to convert to Winograd Conv2D";
3899 if (failed(maybeTransformed)) {
3900 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
3912 FailureOr<Operation *> maybeTransformed = failure();
3915 .Case([&](linalg::WinogradFilterTransformOp op) {
3919 .Case([&](linalg::WinogradInputTransformOp op) {
3923 .Case([&](linalg::WinogradOutputTransformOp op) {
3927 .Default([&](
Operation *op) {
return false; });
3931 emitSilenceableError()
3932 <<
"this operation is not supported to decompose into other operations";
3933 diag.attachNote(target->
getLoc()) <<
"target op";
3937 if (failed(maybeTransformed)) {
3939 emitSilenceableError() <<
"decompose Winograd operations failed";
3940 diag.attachNote(target->
getLoc()) <<
"target op";
3948 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3950 #define GET_OP_CLASSES
3951 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
uint64_t getM(LevelType lt)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.