44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
// Debug-logging helpers, active only under LLVM_DEBUG (e.g. with
// -debug-only=linalg-transforms). DBGS() prefixes the stream with the
// debug-type tag; DBGSNL() emits a bare newline; LDBG(X) logs a single
// tagged line.
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
64 template <
typename PatternTy,
typename... Args>
67 using OpTy =
typename llvm::function_traits<
68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69 auto op = dyn_cast<OpTy>(operation);
74 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
79 explicit TrivialPatternRewriter(
MLIRContext *context)
82 TrivialPatternRewriter rewriter(operation->
getContext());
83 rewriter.setInsertionPoint(operation);
84 auto result = pattern.returningMatchAndRewrite(op, rewriter);
87 return cast<LinalgOp>(result->getOperation());
97 if (
auto attr = dyn_cast<Attribute>(ofr)) {
98 if (!isa<IntegerAttr>(attr))
99 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
100 result.push_back(ofr);
104 Value transformValue = cast<Value>(ofr);
105 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
107 if (params.size() != 1)
108 return transformOp.emitDefiniteFailure()
109 <<
"requires exactly one parameter associated";
110 result.push_back(params[0]);
114 auto payloadOps = state.getPayloadOps(transformValue);
115 if (!llvm::hasSingleElement(payloadOps)) {
117 transformOp.emitSilenceableError()
118 <<
"handle must be mapped to exactly one payload op";
120 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
127 transformOp.emitSilenceableError()
128 <<
"payload op must have exactly 1 index result";
148 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
150 for (
auto param : params) {
151 if (!isa<IntegerAttr>(param))
152 return transformOp.emitDefiniteFailure()
153 <<
"expected the parameter to be associated with an integer "
155 result.push_back(param);
160 for (
Operation *op : state.getPayloadOps(packedHandle)) {
161 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
163 transformOp.emitSilenceableError()
164 <<
"payload op must have exactly 1 index result";
165 diag.attachNote(op->getLoc())
166 <<
"has " << op->getNumResults() <<
" results";
169 result.push_back(op->getResult(0));
183 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
184 reified.push_back(cast<IntegerAttr>(attr).getInt());
186 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
188 if (params.size() != 1)
189 return transformOp.emitSilenceableError() <<
"expected a single param";
191 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
195 Value handle = cast<Value>(paramOrHandle);
196 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
197 return transformOp.emitSilenceableError() <<
"unexpected value handle";
198 auto payload = state.getPayloadOps(handle);
199 if (!llvm::hasSingleElement(payload))
200 return transformOp.emitSilenceableError()
201 <<
"requires param or handle that is mapped to 1 payload op";
203 Operation *paramOrHandlePayloadOp = *payload.begin();
206 return transformOp.emitSilenceableError()
207 <<
"requires param or handle to be result of op with 1 index "
213 return transformOp.emitSilenceableError()
214 <<
"requires param or handle to be the result of a constant like "
217 reified.push_back(attr.getInt());
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
250 options.rankReductionStrategy =
255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
284 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
289 resultTypes.push_back(b.
getType<transform::AnyValueType>());
290 resultTypes.push_back(b.
getType<transform::AnyOpType>());
291 return build(b, result,
297 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
300 int64_t memorySpace) {
302 resultTypes.push_back(b.
getType<transform::AnyValueType>());
303 resultTypes.push_back(b.
getType<transform::AnyOpType>());
304 return build(b, result,
320 void notifyOperationInserted(
Operation *op,
322 ForwardingListener::notifyOperationInserted(op, previous);
324 if (previous.
isSet())
326 auto inserted = newOps.insert(op);
328 assert(inserted.second &&
"expected newly created op");
331 void notifyOperationErased(
Operation *op)
override {
332 ForwardingListener::notifyOperationErased(op);
346 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
347 NewOpsListener newOpsListener(previousListener);
351 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
354 }
else if (getMemcpyOp() ==
"memref.copy") {
357 }
else if (getMemcpyOp() ==
"linalg.copy") {
361 llvm_unreachable(
"invalid memcpy op");
363 if (getAllocOp() ==
"memref.alloc") {
366 }
else if (getAllocOp() ==
"memref.alloca") {
370 llvm_unreachable(
"invalid alloc op");
372 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373 options.emitDealloc = getEmitDealloc();
377 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
379 for (
Operation *op : state.getPayloadOps(getTarget())) {
384 <<
"failed to bufferize operation";
385 diag.attachNote(op->
getLoc()) <<
"target payload op";
388 allocatedBuffers.push_back(buffer);
392 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
393 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
397 void transform::BufferizeToAllocationOp::getEffects(
399 if (getBufferizeDestinationOnly()) {
411 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
412 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
413 return emitOpError() <<
"unsupported memcpy op";
414 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
415 return emitOpError() <<
"unsupported alloc op";
428 #define DOWNSCALE(trans) \
430 FailureOr<LinalgOp> res = tryApply<trans>(target); \
431 if (succeeded(res)) { \
432 results.push_back(*res); \
433 return DiagnosedSilenceableFailure::success(); \
437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
451 #undef DOWNSCALE_NORMAL
452 #undef DOWNSCALE_CALL
454 return emitDefaultSilenceableFailure(target);
468 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
469 if (!decomposableOp) {
471 "payload is not a decomposable op"));
472 return emitDefaultSilenceableFailure(target);
475 FailureOr<SmallVector<Value>> maybeNewResults =
476 decomposableOp.decomposeOperation(rewriter);
477 if (failed(maybeNewResults))
478 return emitDefaultSilenceableFailure(target);
480 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
481 for (
Value val : *maybeNewResults) {
482 Operation *definition = val.getDefiningOp();
493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
504 options.allowReturnAllocsFromLoops =
true;
506 for (
Operation *target : state.getPayloadOps(getTarget())) {
510 <<
"failed to analyze op";
512 rewriter, target, state)))
514 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
525 template <
typename Range>
529 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
535 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
536 if (!tilingInterfaceOp)
537 return transformOp->
emitError(
"only TilingInterface ops are supported");
540 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
541 applyFn(tilingInterfaceOp);
542 if (failed(tiledResults))
547 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548 for (
Operation *toReplace : opsToReplace) {
549 for (
OpResult res : toReplace->getResults())
550 if (
auto replacement = tiledResults->replacements.lookup(res))
552 if (toReplace->use_empty()) {
558 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
559 assert(tiledResults->loops.size() == numLoops &&
560 "Mismatched number of loops, tile and fuse transform should have "
562 for (
unsigned int i = 0; i < numLoops; ++i)
563 loopOps[i].push_back(tiledResults->loops[i]);
566 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
567 for (
unsigned int i = 0; i < numLoops; ++i)
568 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
578 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
580 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
586 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
590 if (getApplyCleanup()) {
593 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
600 rewriter, getOperation(), state.getPayloadOps(getTarget()),
601 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602 [&](TilingInterface tilingInterfaceOp)
603 -> FailureOr<scf::SCFTileAndFuseResult> {
604 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
613 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615 if (!std::is_permutation(sequence.begin(), sequence.end(),
616 permutation.begin(), permutation.end())) {
617 return emitOpError() <<
"expects interchange to be a permutation, found "
618 << getTileInterchange();
622 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624 if (numExpectedLoops != getNumResults() - 1)
625 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
634 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
637 Value containingOp) {
640 result.
addTypes({resultType, resultType});
656 (domInfo.
dominates(containingOp, user))) {
657 dominatedUsers.insert(user);
660 if (dominatedUsers.empty())
664 auto forallOp = cast<scf::ForallOp>(containingOp);
670 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
675 newOuts.push_back(outputs[resultNumber]);
678 auto newforallOp = rewriter.
create<scf::ForallOp>(
679 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
682 newforallOp.getRegion().takeBody(forallOp.getRegion());
687 newforallOp.getBody()->addArgument(newOuts.back().getType(),
688 newOuts.back().getLoc());
689 auto bbArgs = newforallOp.getBody()->getArguments();
697 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
699 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
700 Operation *firstYieldOp = yieldingOps.front();
703 Value dst = newforallOp.getRegionIterArgs().back();
705 rewriter.
create<tensor::ParallelInsertSliceOp>(firstYieldOp->
getLoc(), src,
706 dst, offsets, sizes, strides);
710 newforallOp->getResult(result.index()));
713 newforallOp->getResults().back(),
716 return dominatedUsers.contains(user);
727 static std::tuple<SmallVector<Operation *>,
Operation *>
730 LLVM_DEBUG(
DBGS() <<
"Try to fuse a direct extract use\n");
731 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
732 if (!tileableProducer) {
734 <<
"producer is not a TileableInterface: " << *producerOp;
741 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
742 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
743 return sliceOp && containingOp->isProperAncestor(sliceOp);
747 if (it == tileableProducer->getUsers().end()) {
748 diag.attachNote(tileableProducer->getLoc())
749 <<
"could not find fusion opportunity for: " << *tileableProducer;
752 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
759 int64_t resultNumber =
760 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
761 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
766 FailureOr<TilingResult> tileAndFuseResult =
767 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
770 if (failed(tileAndFuseResult)) {
771 diag.attachNote(tileableProducer->getLoc())
772 <<
"failed to tile producer op: " << *tileableProducer;
777 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
778 LLVM_DEBUG(
DBGS() <<
"tiledProducer: " << *tiledOp <<
"\n");
783 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
784 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
785 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
786 if (failed(maybeRankReduced)) {
788 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
789 << tileAndFuseResult->tiledValues[0]
790 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
793 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
797 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
798 resultNumber, offsets, sizes);
800 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
813 LLVM_DEBUG(
DBGS() <<
"Try to fuse an extract use through block argument\n");
815 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
816 if (!tileableProducer) {
818 <<
"producer is not a TileableInterface: " << *producerOp;
823 scf::ForallOp forallOp;
824 auto itProducerUses =
825 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
826 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
830 if (!forallOp || forallOp != containingOp) {
831 diag.attachNote(tileableProducer->getLoc())
832 <<
"could not find a use by the containing op: " << *tileableProducer;
847 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
848 return sliceOp && containingOp->isProperAncestor(sliceOp);
852 if (itBBArgUsers == bbArg.
getUsers().end()) {
854 <<
"could not find fusion opportunity for bbArg: " << bbArg;
857 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
865 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
866 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
871 rewriter, tileableProducer->getLoc(), tileableProducer,
872 destinationTensors))) {
873 diag.attachNote(tileableProducer->getLoc())
874 <<
"failed to get destination tensors for: " << *tileableProducer;
879 bvm.
map(destinationTensors[resultNumber], bbArg);
880 auto tileableProducerClone =
881 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
883 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
886 FailureOr<TilingResult> tileAndFuseResult =
887 tileableProducerClone.generateResultTileValue(
888 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
889 sliceOpToTile.getMixedSizes());
890 if (failed(tileAndFuseResult)) {
891 diag.attachNote(tileableProducer->getLoc())
892 <<
"failed to tile producer op: " << *tileableProducer;
897 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
898 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
899 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
900 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
901 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
906 destinationTensors.front());
909 return tileAndFuseResult->tiledOps;
915 LLVM_DEBUG(
DBGS() <<
"Try to fuse an use by cloning\n");
920 for (
OpOperand &use : result.getUses()) {
922 uses.push_back(&use);
927 if (containingOp == use.getOwner()) {
929 <<
"producer op use by containing op cannot be fused by cloning";
937 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
946 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
947 "Parallel insert slice is not a valid clone destination");
948 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
949 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
953 fusedOp = rewriter.
clone(*producerOp);
955 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
960 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
970 auto producerOps = state.getPayloadOps(getProducerOp());
971 auto containingOps = state.getPayloadOps(getContainingOp());
972 if (!llvm::hasSingleElement(containingOps)) {
974 <<
"requires exactly one containing_op handle (got "
975 << llvm::range_size(containingOps) <<
")";
977 Operation *containingOp = *containingOps.begin();
980 if (std::empty(producerOps)) {
982 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
990 auto getNextProducer = [&]() -> FailureOr<Operation *> {
991 for (
const auto &it :
enumerate(remainingProducers)) {
994 int64_t numUsesInContainingOp =
996 return containingOp->isAncestor(op);
1001 if (numUsesInContainingOp > 0) {
1002 if (numUsesInContainingOp == 1)
1003 remainingProducers.erase(remainingProducers.begin() + it.index());
1010 while (!remainingProducers.empty()) {
1011 auto nextProducer = getNextProducer();
1012 if (failed(nextProducer)) {
1014 <<
"could not find next producer to fuse into container";
1015 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1023 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1030 auto [tiledOps, newContainingOp] =
1032 if (!tiledOps.empty()) {
1033 LLVM_DEBUG(
DBGS() <<
"\nFused a direct extract use\n" << *containingOp);
1034 fusedOps.append(tiledOps);
1035 if (newContainingOp) {
1043 LogicalResult replacementStatus =
1046 (void)replacementStatus;
1047 assert(succeeded(replacementStatus) &&
1048 "unable to update transform state mapping");
1049 rewriter.
eraseOp(containingOp);
1050 containingOp = newContainingOp;
1057 rewriter,
diag, producerOp, containingOp);
1058 if (!tiledContainingOpOperand.empty()) {
1059 LLVM_DEBUG(
DBGS() <<
"\nFused an extract use through block argument\n"
1061 fusedOps.append(tiledContainingOpOperand);
1068 LLVM_DEBUG(
DBGS() <<
"\nFused an use by cloning\n" << *containingOp);
1069 fusedOps.push_back(cloned);
1075 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1076 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1080 void transform::FuseIntoContainingOp::getEffects(
1098 if (isa<GenericOp>(target)) {
1104 if (succeeded(
generic)) {
1105 results.
push_back(generic->getOperation());
1108 return emitDefaultSilenceableFailure(target);
1121 if (!isa<GenericOp>(target)) {
1126 FailureOr<LinalgOp> named =
1128 if (succeeded(named)) {
1129 results.
push_back(named->getOperation());
1132 return emitDefaultSilenceableFailure(target);
1146 if (interchangeVector.empty()) {
1151 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1152 if (interchangeVector.size() != numLoops) {
1153 return emitSilenceableError()
1154 << getIteratorInterchangeAttrName() <<
" has length ("
1155 << interchangeVector.size()
1156 <<
") different from the number of loops in the target operation ("
1169 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1170 if (!std::is_permutation(sequence.begin(), sequence.end(),
1171 permutation.begin(), permutation.end())) {
1172 return emitOpError()
1173 <<
"expects iterator_interchange to be a permutation, found "
1174 << getIteratorInterchange();
1189 if (!isa<linalg::CopyOp>(targetOp)) {
1191 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1192 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1196 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1197 if (!copyOp.hasPureBufferSemantics()) {
1199 emitSilenceableError()
1200 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1201 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1207 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1208 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1209 Value input = inputs.front();
1210 Value output = outputs.front();
1215 if (!dyn_cast<ShapedType>(input.
getType())) {
1217 emitSilenceableError()
1218 <<
"cannot transform a linalg.copy which input has no shape";
1219 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1224 assert(dyn_cast<ShapedType>(output.
getType()));
1226 if (cast<ShapedType>(input.
getType()).getElementType() !=
1227 cast<ShapedType>(output.
getType()).getElementType()) {
1229 emitSilenceableError()
1230 <<
"cannot transform a linalg.copy with different source and "
1231 "destination element types ";
1232 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1253 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1254 FailureOr<LowerPackResult> res =
1255 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1258 <<
"cannot lower to pad + expand + transpose";
1261 transformResults.
push_back(res->expandShapeOp);
1262 transformResults.
push_back(res->transposeOp);
1275 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1276 FailureOr<LowerUnPackOpResult> res =
1277 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1280 emitSilenceableError()
1281 <<
"cannot lower to transpose + collapse + extract";
1282 diag.attachNote(target->getLoc()) <<
"target payload op";
1285 transformResults.
push_back(res->emptyOp);
1286 transformResults.
push_back(res->transposeOp);
1287 transformResults.
push_back(res->collapseShapeOp);
1288 transformResults.
push_back(res->extractSliceOp);
1318 if (getOps().has_value())
1319 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1321 auto payloadOps = state.getPayloadOps(getTarget());
1322 if (!llvm::hasSingleElement(payloadOps)) {
1327 bool incorrectNumOperandTypes =
false;
1334 if (getInterface().has_value()) {
1335 auto iface = getInterface().value();
1336 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1339 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1340 !isa<TilingInterface>(op))
1342 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1343 !isa<LoopLikeOpInterface>(op))
1348 if (getOpAttrs().has_value()) {
1349 DictionaryAttr opAttrs = getOpAttrs().value();
1351 if (attr.getName() == getInterfaceAttrName() ||
1352 attr.getName() == getOpsAttrName())
1354 if (!op->
hasAttr(attr.getName()))
1356 if (op->
getAttr(attr.getName()) != attr.getValue())
1361 if (getFilterResultType().has_value()) {
1362 Type t = getFilterResultType().value();
1367 if (getFilterOperandTypes().has_value()) {
1368 mlir::ArrayAttr types = getFilterOperandTypes().value();
1371 if (types.size() == 1) {
1374 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1375 Type t = cast<::mlir::Type>(typeattr.getValue());
1377 [&](
Type operandType) { return operandType == t; }))
1382 if (types.size() != operandTypes.size()) {
1383 incorrectNumOperandTypes =
true;
1387 for (
auto [attr, operandType] :
1388 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1389 auto typeattr = cast<mlir::TypeAttr>(attr);
1390 Type type = cast<::mlir::Type>(typeattr.getValue());
1392 if (type != operandType)
1403 (*payloadOps.begin())->
walk(matchFun);
1404 if (incorrectNumOperandTypes)
1406 "type, then it must contain as much types as "
1407 "the number of operands in the target ops");
1408 results.
set(cast<OpResult>(getResult()), res);
1423 Type &targetType,
Type &lowSizeType,
1425 Type &splitPointType) {
1426 FunctionType funcType;
1428 if (failed(parser.
parseType<FunctionType>(funcType)))
1431 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1432 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1433 "argument and one result";
1435 targetType = funcType.getInput(0);
1436 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1444 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1445 if (target.hasDynamicShape()) {
1446 auto diag = emitSilenceableError()
1447 <<
"cannot compute parametric tile sizes for dynamically "
1448 "shaped payload op";
1449 diag.attachNote(target->getLoc()) <<
"payload op";
1454 target, getDimension(), getTargetSize(), getDivisor());
1456 return emitSilenceableError()
1457 <<
"failed to compute multi-size tiling sizes";
1460 Builder builder(target.getContext());
1461 results.
assign(llvm::map_range(
1463 spec->lowTileSize * spec->lowTripCount}),
1464 [&builder,
this](int64_t value) {
1476 builder, target, getDimension(), targetSize, divisor);
1478 return emitSilenceableError() <<
"could not generate tile size computation";
1485 {spec->lowTileSize, spec->lowTripCount});
1486 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1487 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1488 assert(lowTileSize && highTileSize && splitPoint &&
1489 "tile sizes are not produced by operations");
1497 void transform::MultiTileSizesOp::getEffects(
1501 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1510 return emitOpError() <<
"expects all results type to be the same";
1530 builder.
getContext(), GenericOp::getOperationName());
1531 build(builder, result,
1540 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1547 auto targetOps = state.getPayloadOps(getTarget());
1549 if (std::empty(targetOps)) {
1550 transformResults.
set(cast<OpResult>(getPackedOp()),
1555 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1556 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1557 return emitSilenceableError()
1558 <<
"requires target to map to exactly 1 LinalgOp (got "
1559 << llvm::range_size(targetOps) <<
")";
1562 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1563 return emitSilenceableError()
1564 <<
"requires number of packed sizes match the number of loops ("
1565 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1572 state, *
this, packedSizes, getMixedPackedSizes());
1575 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1576 if (failed(maybeResult))
1579 transformResults.
set(cast<OpResult>(getPackedOp()),
1580 {maybeResult->packedLinalgOp.getOperation()});
1584 void transform::PackOp::getEffects(
1598 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1599 <<
" is not a valid permutation";
1602 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1603 for (
auto [s, nmo] :
1604 llvm::zip_equal(getMixedMatmulPackedSizes(),
1605 getMatmulPaddedSizesNextMultipleOf())) {
1608 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1609 return emitOpError() <<
"at most one of the packed_size and the "
1610 "padded_sizes_next_multiple_of can be nonzero "
1611 "for the matmul strategy";
1623 for (
Operation *op : state.getPayloadOps(getTarget())) {
1624 auto linalgOp = dyn_cast<LinalgOp>(op);
1635 getMixedMatmulPackedSizes(),
1637 getMatmulPaddedSizesNextMultipleOf(),
1638 getMatmulInnerDimsOrder());
1639 if (succeeded(packResult)) {
1640 results.push_back(packResult->packedLinalgOp);
1643 results.push_back(linalgOp);
1645 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1651 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1655 void transform::PackGreedilyOp::getEffects(
1669 return emitOpError() << getInnerPermAttrName()
1670 <<
" is not a valid permutation";
1673 return emitOpError() << getOuterPermAttrName()
1674 <<
" is not a valid permutation";
1676 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1677 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1678 <<
" or " << getOuterPermAttrName()
1679 <<
" must be specified";
1685 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1695 template <
typename RelayoutOpTy>
1698 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1700 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1701 "applies to only pack or unpack operations");
1702 if (!op || permutation.empty())
1704 size_t innerRank = op.getInnerDimsPos().size();
1705 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1709 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1710 return permutation.size() == op.getSourceRank() &&
1713 return permutation.size() == op.getDestRank() &&
1721 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1722 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1724 if (std::empty(packOrUnpackOps)) {
1725 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1726 transformResults.
set(cast<OpResult>(getPackOp()), {});
1727 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1733 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1734 !llvm::hasSingleElement(linalgOps)) {
1735 return emitSilenceableError()
1736 <<
"requires target to map to exactly 1 "
1737 "packing op and 1 packed op ("
1738 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1739 << llvm::range_size(linalgOps) <<
")";
1743 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1744 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1745 if ((!packOp && !unPackOp)) {
1746 return emitSilenceableError() <<
"requires target to map to a "
1747 "linalg.pack or linalg.unpack";
1749 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1750 if (!linalgOpTarget)
1751 return emitSilenceableError() <<
"requires a LinalgOp target";
1755 if (packOp && packOp.getResult().hasOneUse())
1756 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1758 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1759 if (linalgOp != linalgOpTarget) {
1761 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1762 : StringLiteral{
"not produced by the LinalgOp target"};
1763 return emitSilenceableError() << errorMsg;
1769 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1770 OpOperand *packUse = linalgOp.getDpsInitOperand(
1771 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1773 if (!packOp || !packOp.getResult().hasOneUse())
1774 return emitSilenceableError() <<
"could not find matching pack op";
1778 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1780 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1781 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1782 ? StringLiteral{
"invalid outer_perm"}
1783 : StringLiteral{
"invalid inner_perm"};
1787 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1788 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1794 assert(packOp && linalgOp &&
"unexpected null op");
1798 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1800 assert(succeeded(res) &&
"unexpected packTranspose failure");
1803 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1804 transformResults.
set(cast<OpResult>(getPackedOp()),
1805 {res->transposedLinalgOp});
1807 transformResults.
set(cast<OpResult>(getUnPackOp()),
1808 {res->transposedUnPackOp});
1810 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1825 StringRef copyBackOp) {
1835 (padToMultipleOf.empty()
1837 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1838 b.getI64ArrayAttr(nofoldFlags),
1839 b.getArrayAttr(transposePaddings),
1840 b.getStringAttr(copyBackOp));
1848 StringRef copyBackOp) {
1853 staticPadToMultipleOf);
1860 dynamicPadToMultipleOf,
1861 staticPadToMultipleOf,
1867 void PadOp::getEffects(
1877 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1884 auto transformOp = cast<TransformOpInterface>(getOperation());
1887 for (
Operation *target : state.getPayloadOps(getTarget())) {
1888 auto linalgTarget = dyn_cast<LinalgOp>(target);
1889 if (!linalgTarget) {
1890 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1891 diag.attachNote(target->
getLoc()) <<
"target op";
1897 for (int64_t packPadding :
1898 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1899 nofoldFlags.push_back(
static_cast<bool>(packPadding));
1903 for (
auto const &it :
1904 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1905 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1907 emitOpError(
"expects padding values to be typed attributes");
1912 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
1916 if (!parsedAttr || parsedAttr.getType() != elementType) {
1917 auto diag = this->emitOpError(
"expects a padding that parses to ")
1918 << elementType <<
", got " << std::get<0>(it);
1919 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1922 paddingValues.push_back(parsedAttr);
1926 if (attr.getType() != elementType) {
1927 auto diag = this->emitOpError(
"expects a padding value of type ")
1928 << elementType <<
", got " << attr;
1929 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1932 paddingValues.push_back(attr);
1937 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1938 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1939 cast<ArrayAttr>(transposeVector)));
1944 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1948 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1951 if (padToMultipleOf.empty())
1955 options.padToMultipleOf = padToMultipleOf;
1956 options.paddingValues = paddingValues;
1957 options.nofoldFlags = nofoldFlags;
1958 if (getCopyBackOp() ==
1959 bufferization::MaterializeInDestinationOp::getOperationName()) {
1962 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1964 }
else if (getCopyBackOp() == kCopyOpNone) {
1967 llvm_unreachable(
"unsupported copy_back op");
1973 replacements, newPadOps))) {
1974 auto diag = emitSilenceableError() <<
"failed to pad op";
1975 diag.attachNote(target->
getLoc()) <<
"target op";
1984 rewriter.
replaceOp(linalgTarget, replacements);
1985 paddedOps.push_back(paddedOp);
1986 padOps.append(newPadOps.begin(), newPadOps.end());
1988 for (
Value v : replacements) {
1989 Operation *copyBackOp = v.getDefiningOp();
1990 if (!llvm::is_contained(copyBackOps, copyBackOp))
1991 copyBackOps.push_back(copyBackOp);
1996 results.
set(cast<OpResult>(getPadded()), paddedOps);
1997 results.
set(cast<OpResult>(getPad()), padOps);
1998 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2004 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2005 if (any_of(nofoldFlags, [](int64_t packPadding) {
2006 return packPadding != 0 && packPadding != 1;
2008 return emitOpError()
2009 <<
"expects nofold_flags to contain booleans (0/1), found "
2010 << getNofoldFlags();
2014 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2015 if (any_of(paddingDimensions,
2016 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2017 return emitOpError() <<
"expects padding_dimensions to contain positive "
2019 << getPaddingDimensions();
2021 if (!getMixedPadToMultipleOf().empty()) {
2022 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2023 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2026 ArrayAttr transposes = getTransposePaddings();
2029 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2030 if (!std::is_permutation(sequence.begin(), sequence.end(),
2032 return emitOpError()
2033 <<
"expects transpose_paddings to be a permutation, found "
2037 if (getCopyBackOp() !=
2038 bufferization::MaterializeInDestinationOp::getOperationName() &&
2039 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2040 getCopyBackOp() != kCopyOpNone)
2041 return emitOpError() <<
"invalid copy_back_op";
2053 auto targetOps = state.getPayloadOps(getTarget());
2054 auto loopOps = state.getPayloadOps(getLoop());
2055 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2057 <<
"requires exactly one target and one loop handle (got "
2058 << llvm::range_size(targetOps) <<
" and "
2059 << llvm::range_size(loopOps) <<
")";
2062 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2063 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2064 if (!padOp || !loopOp)
2067 FailureOr<linalg::detail::PackingResult> result =
2073 if (result->clonedLoopIvs.empty()) {
2074 transformResults.
set(cast<OpResult>(getPackingLoop()),
2075 {result->hoistedPadOp.getOperation()});
2078 auto outerPackedLoop =
2080 transformResults.
set(cast<OpResult>(getPackingLoop()),
2081 {outerPackedLoop.getOperation()});
2087 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2088 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2090 return emitOpError() <<
"expects transpose to be a permutation, found "
2096 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2106 tensor::PadOp target,
2109 tensor::PadOp hoistedPadOp;
2111 FailureOr<Value> result =
2113 hoistedPadOp, transposeOps);
2114 if (succeeded(result)) {
2124 return emitDefaultSilenceableFailure(target);
2129 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2130 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2132 return emitOpError() <<
"expects transpose to be a permutation, found "
2148 if (!getOperandsToPromote().empty())
2150 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2151 if (getUseFullTilesByDefault())
2153 getUseFullTilesByDefault());
2155 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2156 if (!getUseFullTileBuffers().empty())
2158 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2159 if (getAlignment().has_value())
2160 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2161 if (getMemorySpace().has_value())
2162 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2164 if (getMapping().has_value()) {
2166 auto mapping = *getMapping();
2167 if (mapping.size() > 1)
2168 return emitDefaultDefiniteFailure(target);
2170 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2172 if (addressSpace.getAddressSpace() ==
2173 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2180 }
else if (addressSpace.getAddressSpace() ==
2181 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2189 return emitDefaultDefiniteFailure(target);
2194 return emitDefaultDefiniteFailure(target);
2197 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2199 return emitDefaultDefiniteFailure(target);
2212 auto payload = state.getPayloadOps(getTarget());
2216 if (target->getNumOperands() > 0)
2219 target->getNumRegions() > 0)
2221 <<
"expected target that is isolated from above";
2225 Operation *pattern = &getBodyRegion().front().front();
2228 if (getOperation()->isAncestor(target))
2233 replacements.push_back(replacement);
2235 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2239 void transform::ReplaceOp::getEffects(
2247 if (!getBodyRegion().hasOneBlock())
2248 return emitOpError() <<
"expected one block";
2249 if (std::distance(getBodyRegion().front().begin(),
2250 getBodyRegion().front().end()) != 1)
2251 return emitOpError() <<
"expected one operation in block";
2252 Operation *replacement = &getBodyRegion().front().front();
2255 <<
"expected replacement without operands";
2259 <<
"expect op that is isolated from above";
2277 target.createFlatListOfOperandDims(b, loc);
2278 AffineMap map = target.getShapesToLoopsMap();
2294 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2295 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2296 if (failed(maybeTilingResult))
2297 return emitDefaultDefiniteFailure(target);
2299 if (target->getNumResults())
2300 rewriter.
replaceOp(target, maybeTilingResult->mergeResult.replacements);
2304 results.
reserve(maybeTilingResult->tiledOps.size());
2305 for (
Operation *tiled : maybeTilingResult->tiledOps)
2319 for (
Operation *target : state.getPayloadOps(getTarget())) {
2320 auto tilingOp = dyn_cast<TilingInterface>(*target);
2323 emitSilenceableError()
2324 <<
"expected the payload to implement TilingInterface";
2325 diag.attachNote(target->getLoc()) <<
"payload op";
2329 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2331 if (failed(generatedLoops))
2332 return emitDefaultDefiniteFailure(target);
2333 for (scf::ForOp &loop : *generatedLoops) {
2334 loops.push_back(loop.getOperation());
2338 results.
set(cast<OpResult>(getResult()), loops);
2347 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2353 FailureOr<Operation *> maybeResult =
2355 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2356 [&rewriter](
auto op) {
2359 if (failed(maybeResult))
2360 return emitDefaultSilenceableFailure(target);
2374 llvm::to_vector(state.getPayloadOps(getTarget()));
2376 bool isMultiwaySplit = getMultiway();
2378 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2380 <<
"requires exactly one target when "
2381 "multiway split is enabled (got "
2382 << llvm::range_size(payload) <<
")";
2387 if (!isMultiwaySplit)
2388 chunkSizes.reserve(payload.size());
2390 if (getDynamicChunkSizes()) {
2392 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2393 chunkSizes = llvm::to_vector(llvm::map_range(
2394 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2397 diag = emitSilenceableError()
2398 <<
"expected dynamic split point handle to point to a "
2399 "single-result index-typed op";
2400 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2405 chunkSizes = llvm::to_vector(
2406 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2409 if (
diag.isSilenceableFailure())
2414 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2416 <<
"expected the dynamic split point handle to point to as "
2418 << chunkSizes.size() <<
") as the target handle ("
2419 << payload.size() <<
")";
2422 chunkSizes.resize(payload.size(),
2426 auto checkStructuredOpAndDimensions =
2429 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2430 diag.attachNote(loc) <<
"target op";
2434 if (getDimension() >= linalgOp.getNumLoops()) {
2435 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2436 <<
" does not exist in target op";
2437 diag.attachNote(loc) <<
"target op";
2443 auto checkFailureInSplitting =
2447 diag.attachNote(loc) <<
"target op";
2454 if (isMultiwaySplit) {
2457 TilingInterface head, tail;
2460 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2464 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2465 if (
diag.isSilenceableFailure())
2471 target = tail.getOperation();
2476 linalgOp = cast<LinalgOp>(target);
2481 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2482 getDimension(), chunkSize);
2486 checkFailureInSplitting(!head && !tail, loc);
2487 if (
diag.isDefiniteFailure())
2490 opList.push_back(head.getOperation());
2495 opList.push_back(tail.getOperation());
2501 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2504 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2506 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2508 if (
diag.isSilenceableFailure())
2512 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2513 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2514 getDimension(), std::get<1>(pair));
2518 checkFailureInSplitting(!first.back() && !second.back(), loc);
2523 if (!second.back()) {
2524 noSecondPart = target;
2529 if (second.size() != first.size() && !second.empty()) {
2530 auto diag = emitSilenceableError()
2531 <<
"splitting does not produce the second part for a subset "
2534 <<
"expected splitting to produce the second part of all "
2535 "or none of the targets";
2537 <<
"first target with no second part";
2541 opList.append(first);
2543 opList.append(second);
2545 results.
set(cast<OpResult>(getSplitList()), opList);
2549 void SplitOp::getEffects(
2552 if (getDynamicChunkSizes())
2560 IntegerAttr staticChunkSizes;
2566 if (!dynamicPointParseResult.
has_value()) {
2567 int64_t staticChunkSizesValue;
2568 if (failed(parser.
parseInteger(staticChunkSizesValue)))
2581 if (dynamicPointParseResult.
has_value()) {
2582 Type ChunkSizesType;
2583 if (failed(*dynamicPointParseResult) || parser.
parseComma() ||
2595 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2602 printer <<
" " << getTarget() <<
" after ";
2603 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2604 if (staticChunkSize != ShapedType::kDynamic)
2605 printer << staticChunkSize;
2607 printer << getDynamicChunkSizes();
2610 {getStaticChunkSizesAttrName()});
2611 printer <<
" : " << getTarget().getType();
2612 if (staticChunkSize == ShapedType::kDynamic)
2613 printer <<
", " << getDynamicChunkSizes().getType();
2617 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2618 (getDynamicChunkSizes() ==
nullptr)) {
2619 return emitOpError() <<
"expects either a dynamic or a static split "
2620 "point to be provided";
2629 void transform::SplitReductionOp::build(
2631 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2632 bool useScalingAlgorithm,
bool useAlloc) {
2638 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2640 if (innerParallel) {
2641 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2644 if (useScalingAlgorithm) {
2646 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2654 result.
addTypes({resultType, resultType, resultType, resultType});
2663 unsigned(getInsertSplitDimension()),
2664 bool(getInnerParallel())};
2667 FailureOr<SplitReductionResult> splitResult =
2668 (getUseScalingAlgorithm())
2671 if (failed(splitResult))
2672 return emitDefaultDefiniteFailure(target);
2674 results.
push_back(splitResult->initOrAlloc);
2676 results.
push_back(splitResult->splitLinalgOp);
2677 results.
push_back(splitResult->resultCombiningLinalgOp);
2685 void transform::TileReductionUsingForOp::build(
2696 build(builder, result,
2699 staticTileSizesAttr);
2708 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2709 if (!partialReductionOp) {
2712 "Operation should implement PartialReductionOpInterface");
2715 rewriter, partialReductionOp,
2719 return emitDefaultSilenceableFailure(target);
2720 rewriter.
replaceOp(target, result->mergeResult.replacements);
2721 for (
Value initValue : result->initialValues)
2723 for (
auto parallelTiledOp : result->tiledOps)
2725 for (
auto mergeOp : result->mergeResult.mergeOps)
2727 results.
push_back(result->loops.front());
2735 void transform::TileReductionUsingForallOp::build(
2738 ArrayAttr mapping) {
2748 build(builder, result,
2751 staticNumThreadsAttr,
2752 staticTileSizesAttr,
2765 FailureOr<linalg::ForallReductionTilingResult> result =
2767 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2768 numThreads, tileSizes, getMapping());
2770 if (failed(result)) {
2771 auto diag = emitSilenceableError() <<
"could not tile reduction";
2772 diag.attachNote(target.getLoc()) <<
"target operation";
2775 for (
Value initValue : result->initialValues)
2777 for (
auto parallelTiledOp : result->parallelTiledOps)
2779 for (
auto mergeOp : result->mergeOps)
2795 llvm::to_vector(state.getPayloadOps(getTarget()));
2797 if (!llvm::hasSingleElement(targetOps)) {
2799 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
2804 auto linalgOp = dyn_cast<LinalgOp>(target);
2805 auto tileableOp = dyn_cast<TilingInterface>(target);
2810 OpBuilder builder(linalgOp.getContext());
2812 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
2813 if (linalgOp.hasDynamicShape()) {
2814 auto diag = emitSilenceableError()
2815 <<
"cannot compute parametric tile sizes for dynamically "
2816 "shaped payload op";
2817 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
2821 FailureOr<StaticContinuousTileSizeSpecification> spec =
2825 return emitSilenceableError()
2826 <<
"failed to compute multi-size tiling sizes";
2831 for (
auto &&[tileSize, tripCount] :
2832 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2833 chunkSizes.push_back(tileSize * tripCount);
2836 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
2840 transformResults.
setParams(cast<OpResult>(getTileSizes()),
2841 getI64AttrsFromI64(spec->tileSizes));
2842 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
2843 getI64AttrsFromI64(chunkSizes));
2851 unsigned dimension = getDimension();
2854 builder, tileableOp, dimension, targetSize,
true);
2856 return emitSilenceableError() <<
"could not generate tile size computation";
2868 for (
auto &&[tileSize, tripCount] :
2869 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2870 splitPoint = apply(s0 * s1, {tileSize, tripCount});
2871 chunkSizes.push_back(splitPoint);
2875 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
2880 transformResults.
set(cast<OpResult>(getTileSizes()),
2881 getDefiningOps(spec->tileSizes));
2882 transformResults.
set(cast<OpResult>(getChunkSizes()),
2883 getDefiningOps(chunkSizes));
2891 return emitOpError() <<
"expects all results type to be the same";
2897 void transform::ContinuousTileSizesOp::getEffects(
2899 if (isa<TransformParamTypeInterface>(getTileSizes().
getType()))
2915 Type &tileSizesType,
2916 Type &chunkSizesType) {
2917 FunctionType funcType;
2919 if (failed(parser.
parseType<FunctionType>(funcType)))
2922 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2923 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
2924 "argument and one result";
2926 targetType = funcType.getInput(0);
2927 tileSizesType = chunkSizesType = funcType.getResult(0);
2936 void transform::TileUsingForOp::build(
2941 return build(builder, result, loopTypes,
2945 interchange, scalableSizes);
2948 void transform::TileUsingForOp::build(
2952 build(builder, result, target,
2954 interchange, scalableSizes);
2957 void transform::TileUsingForOp::build(
2964 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2968 void transform::TileUsingForOp::build(
2980 unsigned numExpectedLoops =
2981 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2983 resultTypes.reserve(numExpectedLoops);
2984 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2985 "expected one loop type or as many as loops");
2986 if (loopTypes.size() == 1)
2987 resultTypes.append(numExpectedLoops, loopTypes[0]);
2989 llvm::append_range(resultTypes, loopTypes);
2991 if (scalableSizes.has_value())
2992 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2993 build(builder, result, target.
getType(),
2997 staticTileSizesAttr,
2999 expandedScalableSizes);
3004 return emitOpError(
"expected same number of sizes (")
3006 << getScalableSizes().size() <<
")";
3008 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3009 if (getLoops().size() != numExpectedLoops)
3010 return emitOpError(
"expected number of loops to tile (")
3011 << numExpectedLoops <<
") to match number of `loops` results ("
3012 << getLoops().size() <<
")";
3023 llvm::to_vector(state.getPayloadOps(getTarget()));
3029 if (isa<ParamType>(transformValue.getType())) {
3030 dynamicSizeProducers.push_back({});
3032 paramSizes.push_back(
3033 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
3034 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3037 if (paramSizes.back().size() != targets.size()) {
3039 emitSilenceableError()
3040 <<
"expected as many parameter values ("
3041 << dynamicSizeProducers.back().size() <<
") as target ops ("
3042 << targets.size() <<
")";
3043 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3049 paramSizes.push_back({});
3050 dynamicSizeProducers.push_back(
3051 llvm::to_vector(state.getPayloadOps(transformValue)));
3053 if (dynamicSizeProducers.back().size() != targets.size()) {
3055 emitSilenceableError()
3056 <<
"expected as many dynamic size-producing operations ("
3057 << dynamicSizeProducers.back().size() <<
") as target ops ("
3058 << targets.size() <<
")";
3059 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3063 for (
Operation *op : dynamicSizeProducers.back()) {
3070 emitSilenceableError() <<
"expected sizes to be produced by ops "
3071 "with a single index-type result";
3072 diag.attachNote(op->
getLoc()) <<
"size producer op";
3073 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3080 loops.resize(getLoops().size());
3081 auto scalableSizes = getScalableSizes();
3083 auto tilingInterface = dyn_cast<TilingInterface>(op);
3084 if (!tilingInterface) {
3086 emitSilenceableError()
3087 <<
"only ops implementing TilingInterface are supported";
3088 diag.attachNote(op->
getLoc()) <<
"target op";
3091 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3093 emitSilenceableError()
3094 <<
"too many tiles provided, expected at most "
3095 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3096 << tileSizes.size();
3097 diag.attachNote(op->
getLoc()) <<
"target op";
3102 if (tileSizes.empty()) {
3111 sizes.reserve(tileSizes.size());
3112 unsigned dynamicIdx = 0;
3115 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3116 if (scalableSizes[ofrIdx]) {
3117 auto val = b.
create<arith::ConstantIndexOp>(
3118 getLoc(), cast<IntegerAttr>(attr).getInt());
3124 sizes.push_back(attr);
3131 assert((dynamicSizes.empty() ^ params.empty()) &&
3132 "expected either dynamic sizes or parameters");
3133 if (!params.empty()) {
3136 sizes.push_back(dynamicSizes[index]->getResult(0));
3144 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3145 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3146 if (failed(maybeTilingResult))
3149 rewriter.
replaceOp(op, maybeTilingResult->mergeResult.replacements);
3151 tiled.append(maybeTilingResult->tiledOps);
3153 loops[en2.index()].push_back(en2.value());
3156 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3158 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3167 results.reserve(tileSizes.size());
3168 unsigned dynamicPos = 0;
3170 for (int64_t size : tileSizes) {
3171 if (size == ShapedType::kDynamic) {
3172 results.push_back(dynamic[dynamicPos++]);
3180 void transform::TileUsingForOp::getEffects(
3192 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3196 ArrayAttr mapping) {
3197 return build(builder, result,
3205 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3209 ArrayAttr mapping) {
3219 build(builder, result,
3220 TypeRange{operationType, operationType},
3227 staticTileSizesAttr,
3231 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3235 ArrayAttr mapping) {
3236 return build(builder, result, target,
3241 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3245 ArrayAttr mapping) {
3256 build(builder, result,
3257 TypeRange{operationType, operationType},
3263 staticNumThreadsAttr,
3276 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3278 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3280 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3281 normalizedUbs.push_back(normalizedUb);
3283 return normalizedUbs;
3299 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3302 denormalizedIvs.push_back(
3305 return denormalizedIvs;
3316 scf::ForallOp loop) {
3336 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
3337 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3340 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3342 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3347 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3348 normalizedForallOp.getRegionIterArgs().end());
3349 Block *origLoopBlock = loop.getBody();
3350 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3352 rewriter.
replaceOp(loop, normalizedForallOp);
3353 return normalizedForallOp;
3358 TransformOpInterface transformOp,
Operation *target,
3363 auto tileableOp = dyn_cast<TilingInterface>(target);
3366 transformOp.emitSilenceableError()
3367 <<
"only TilingInterface ops are supported";
3368 diag.attachNote(target->
getLoc()) <<
"target op";
3374 if (!mixedNumThreads.empty()) {
3375 options.setNumThreads(mixedNumThreads);
3377 options.setTileSizes(mixedTileSizes);
3380 options.setMapping(mapping.value().getValue());
3382 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3385 if (failed(maybeTilingResult))
3386 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3388 rewriter.
replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3390 tilingResult = *maybeTilingResult;
3392 if (mixedNumThreads.empty()) {
3393 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3396 scf::ForallOp normalizedForallOp =
3398 tilingResult.
loops.front() = normalizedForallOp;
3408 auto transformOp = cast<TransformOpInterface>(getOperation());
3417 getPackedNumThreads()
3419 state, transformOp, mixedNumThreads, getPackedNumThreads())
3421 state, transformOp, mixedNumThreads, getMixedNumThreads());
3425 status = getPackedTileSizes()
3427 state, transformOp, mixedTileSizes, getPackedTileSizes())
3429 state, transformOp, mixedTileSizes, getMixedTileSizes());
3433 for (
Operation *target : state.getPayloadOps(getTarget())) {
3436 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3437 getMapping(), tilingResult);
3438 if (!
diag.succeeded())
3440 tileOps.push_back(tilingResult.
loops.front());
3441 tiledOps.append(tilingResult.
tiledOps);
3444 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3445 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3450 void transform::TileUsingForallOp::getEffects(
3463 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3472 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3473 static_cast<int>(getPackedNumThreads() !=
Value());
3474 if (numThreadsSpec > 1)
3476 "num_threads and packed_num_threads are mutually exclusive");
3477 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3478 static_cast<int>(getPackedTileSizes() !=
Value());
3479 if (tileSizesSpec > 1)
3481 "tile_sizes and packed_tile_sizes are mutually exclusive");
3482 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3483 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3484 "must be specified");
3492 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3494 bool vectorizePadding,
bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3496 if (vectorizePadding) {
3498 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3502 if (vectorizeExtract) {
3504 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3508 if (flatten1DDepthwiseConv) {
3510 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3521 explicit VectorizationPattern(
MLIRContext *context,
3522 bool vectorizeExtract =
false,
3523 bool flattenConv =
false)
3525 vectorizeNDExtract(vectorizeExtract),
3526 flatten1DDepthwiseConv(flattenConv) {}
3527 LogicalResult matchAndRewrite(
Operation *op,
3531 "Unsupported Op, cannot vectorize");
3533 {}, vectorizeNDExtract,
3534 flatten1DDepthwiseConv);
3540 bool vectorizeNDExtract =
false;
3544 bool flatten1DDepthwiseConv =
false;
3549 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3554 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3555 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3561 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3562 getFlatten_1dDepthwiseConv());
3564 if (!getDisableTransferPermutationMapLoweringPatterns())
3567 if (!getDisableMultiReductionToContractPatterns())
3575 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
3576 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
3581 if (getVectorizePadding()) {
3591 config.listener = &listener;
3593 return emitDefaultDefiniteFailure(target);
3607 auto targets = state.getPayloadOps(getTarget());
3608 if (std::empty(targets))
3610 auto transformOp = cast<TransformOpInterface>(getOperation());
3613 state, transformOp, getMixedVectorSizes(), vectorSizes);
3621 <<
"Unsupported Op, cannot vectorize";
3626 getVectorizeNdExtract().value_or(
false)))) {
3628 <<
"Attempted to vectorize, but failed";
3635 void transform::VectorizeOp::getEffects(
3644 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3648 if (getStaticVectorSizes().size() != getScalableSizes().size())
3649 return emitOpError(
"expected same number of vector sizes (")
3650 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3651 << getScalableSizes().size() <<
")";
3660 transform::HoistRedundantVectorTransfersOp::applyToOne(
3677 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3696 auto maybeTransformed =
3699 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3702 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3705 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3708 .Case([&](linalg::Conv2DNchwFchwOp op) {
3714 if (failed(maybeTransformed))
3715 return emitDefaultSilenceableFailure(target);
3717 results.
push_back(maybeTransformed->first);
3719 results.
push_back(maybeTransformed->second);
3734 <<
"only elementwise flattening is supported";
3737 if (target.getNumLoops() <= 1) {
3744 std::iota(reassociation.begin(), reassociation.end(), 0);
3745 auto maybeFlattened =
3747 if (failed(maybeFlattened))
3749 <<
"attempted to flatten, but failed";
3750 results.
push_back(maybeFlattened->collapsedOp);
3751 rewriter.
replaceOp(target, maybeFlattened->results);
3764 auto maybeTransformed =
3766 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3769 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3775 if (failed(maybeTransformed))
3776 return emitDefaultSilenceableFailure(target);
3791 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3792 auto maybeTransformed =
3794 .Case([&](linalg::MatmulOp op) {
3797 .Case([&](linalg::BatchMatmulOp op) {
3800 .Default([&](
Operation *op) {
return failure(); });
3801 if (failed(maybeTransformed))
3811 template <
typename OpTy>
3815 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3816 tensor::ParallelInsertSliceOp>() &&
3819 if (
auto copySource =
3820 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3827 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3829 target->template getParentOfType<scf::InParallelOp>());
3832 Value extracted = rewriter.
create<tensor::ExtractSliceOp>(
3833 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3834 target.getMixedSizes(), target.getMixedStrides());
3835 Value copied = rewriter
3836 .
create<linalg::CopyOp>(target.getLoc(),
3837 target.getSource(), extracted)
3842 target, copied, target.getDest(), target.getMixedOffsets(),
3843 target.getMixedSizes(), target.getMixedStrides());
3845 results.
push_back(copied.getDefiningOp());
3855 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3856 return doit(rewriter, target, results, state);
3857 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3858 return doit(rewriter, target, results, state);
3861 emitSilenceableError()
3862 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3863 diag.attachNote(targetOp->
getLoc()) <<
"target op";
3876 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3878 emitSilenceableError()
3879 <<
"only linalg.copy and tensor.pad target ops are supported";
3880 diag.attachNote(target->
getLoc()) <<
"target op";
3883 assert(target->
getNumResults() == 1 &&
"expected single result");
3885 if (!resultShapedType.hasStaticShape()) {
3887 emitSilenceableError()
3888 <<
"only statically sized ops of rank <= 3 are supported";
3889 diag.attachNote(target->
getLoc()) <<
"target op";
3894 int64_t desiredBitAlignment = getDesiredBitAlignment();
3895 int64_t eltBitwidth =
3896 resultShapedType.getElementType().getIntOrFloatBitWidth();
3897 if (desiredBitAlignment % eltBitwidth != 0) {
3898 desiredBitAlignment = eltBitwidth;
3903 getTotalNumThreads(),
3904 desiredBitAlignment,
3905 resultShapedType.getShape(),
3908 resultShapedType.getElementType().getIntOrFloatBitWidth());
3909 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3911 emitSilenceableError()
3912 <<
"too few threads to map copy op to threads on the most minor "
3913 "dimension, given alignment and vector size constraints, try "
3914 "smaller tile size of mapping to more threads";
3915 diag.attachNote(target->
getLoc()) <<
"target op";
3931 if (!
diag.succeeded())
3935 for (
auto op : tilingResult.
tiledOps)
3949 FailureOr<Operation *> maybeTransformed = failure();
3951 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3956 .Default([&](
Operation *op) {
return false; });
3959 return emitSilenceableError()
3960 <<
"this operation is not supported to convert to Winograd Conv2D";
3963 if (failed(maybeTransformed)) {
3964 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
3976 FailureOr<Operation *> maybeTransformed = failure();
3979 .Case([&](linalg::WinogradFilterTransformOp op) {
3983 .Case([&](linalg::WinogradInputTransformOp op) {
3987 .Case([&](linalg::WinogradOutputTransformOp op) {
3991 .Default([&](
Operation *op) {
return false; });
3995 emitSilenceableError()
3996 <<
"this operation is not supported to decompose into other operations";
3997 diag.attachNote(target->
getLoc()) <<
"target op";
4001 if (failed(maybeTransformed)) {
4003 emitSilenceableError() <<
"decompose Winograd operations failed";
4004 diag.attachNote(target->
getLoc()) <<
"target op";
4012 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4014 #define GET_OP_CLASSES
4015 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
uint64_t getM(LevelType lt)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.