43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Debug.h"
47 #include <type_traits>
53 #define DEBUG_TYPE "linalg-transforms"
54 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 #define DBGSNL() (llvm::dbgs() << "\n")
56 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
// Helper template: instantiates pattern `PatternTy` and invokes its
// `returningMatchAndRewrite` entry point directly on `operation`.
// NOTE(review): interior lines of this excerpt are missing; comments
// describe only the visible statements.
63 template <
typename PatternTy,
typename... Args>
// Deduce the concrete op type from the first parameter of
// PatternTy::returningMatchAndRewrite via llvm::function_traits.
66 using OpTy =
typename llvm::function_traits<
67 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
68 auto op = dyn_cast<OpTy>(operation);
// Forward extra constructor arguments to the pattern instance.
73 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
78 explicit TrivialPatternRewriter(
MLIRContext *context)
// Apply the pattern with the insertion point anchored at `operation`.
81 TrivialPatternRewriter rewriter(operation->
getContext());
82 rewriter.setInsertionPoint(operation);
83 auto result = pattern.returningMatchAndRewrite(op, rewriter);
86 return cast<LinalgOp>(result->getOperation());
// Helper: turns a transform param/handle into index-typed OpFoldResults.
// Params must be IntegerAttrs; handles must map to exactly one payload op
// with a single index result. NOTE(review): fragmentary excerpt — interior
// lines are missing.
97 if (!isa<IntegerAttr>(ofr.get<
Attribute>()))
98 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
99 result.push_back(ofr);
// Param-typed values: exactly one associated parameter is accepted.
104 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
106 if (params.size() != 1)
107 return transformOp.emitDefiniteFailure()
108 <<
"requires exactly one parameter associated";
109 result.push_back(params[0]);
// Handle-typed values: the handle must reference a single payload op.
113 auto payloadOps = state.getPayloadOps(transformValue);
114 if (!llvm::hasSingleElement(payloadOps)) {
116 transformOp.emitSilenceableError()
117 <<
"handle must be mapped to exactly one payload op";
119 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
// Silenceable (recoverable) failure when the payload op does not produce
// exactly one index result.
126 transformOp.emitSilenceableError()
127 <<
"payload op must have exactly 1 index result";
// Packed-handle variant of the unpack helper: iterates over all associated
// params (each must be an IntegerAttr) or all payload ops (each must have a
// single index result). NOTE(review): fragmentary excerpt.
147 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
149 for (
auto param : params) {
150 if (!isa<IntegerAttr>(param))
151 return transformOp.emitDefiniteFailure()
152 <<
"expected the parameter to be associated with an integer "
154 result.push_back(param);
// Otherwise the packed handle maps to payload ops; validate each one.
159 for (
Operation *op : state.getPayloadOps(packedHandle)) {
162 transformOp.emitSilenceableError()
163 <<
"payload op must have exactly 1 index result";
// Helper: reifies a mixed list of attributes / params / value handles into
// concrete int64_t values. Attributes are read directly; params must be a
// single IntegerAttr; handles must map to one payload op that is
// constant-like with a single index result. NOTE(review): fragmentary.
182 if (isa<Attribute>(paramOrHandle)) {
184 cast<IntegerAttr>(paramOrHandle.get<
Attribute>()).getInt());
186 }
else if (isa<ParamType>(paramOrHandle.get<
Value>().
getType())) {
188 if (params.size() != 1)
189 return transformOp.emitSilenceableError() <<
"expected a single param";
// Use signed extraction since the reified values are int64_t.
191 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
196 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
197 return transformOp.emitSilenceableError() <<
"unexpected value handle";
198 auto payload = state.getPayloadOps(handle);
199 if (!llvm::hasSingleElement(payload))
200 return transformOp.emitSilenceableError()
201 <<
"requires param or handle that is mapped to 1 payload op";
203 Operation *paramOrHandlePayloadOp = *payload.begin();
206 return transformOp.emitSilenceableError()
207 <<
"requires param or handle to be result of op with 1 index "
213 return transformOp.emitSilenceableError()
214 <<
"requires param or handle to be the result of a constant like "
217 reified.push_back(attr.getInt());
// Apply*PatternsOp::populatePatterns overloads — each registers the
// corresponding Linalg rewrite-pattern set into the pattern list.
// NOTE(review): bodies are mostly elided in this excerpt; only signatures
// and one option assignment are visible.
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
231 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
237 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
// Slices variant selects a different rank-reduction strategy than the
// reshape variant above.
240 options.rankReductionStrategy =
245 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
// BufferizeToAllocationOp::build overloads. Both populate the two result
// types (AnyValueType for the allocated buffer, AnyOpType for new ops) and
// delegate to the generated builder; the second overload additionally takes
// a memory space as int64_t. NOTE(review): fragmentary excerpt.
254 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
259 resultTypes.push_back(b.
getType<transform::AnyValueType>());
260 resultTypes.push_back(b.
getType<transform::AnyOpType>());
261 return build(b, result,
// Overload with an integer memory space attribute.
267 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
270 int64_t memorySpace) {
272 resultTypes.push_back(b.
getType<transform::AnyValueType>());
273 resultTypes.push_back(b.
getType<transform::AnyOpType>());
274 return build(b, result,
// Listener that records all newly inserted ops (and forgets erased ones) so
// the transform can report them as a result. Forwards every notification to
// the wrapped listener first. NOTE(review): fragmentary excerpt.
290 void notifyOperationInserted(
Operation *op,
292 ForwardingListener::notifyOperationInserted(op, previous);
// Only ops inserted fresh (no previous location) are tracked — moves are
// presumably skipped; TODO confirm against the full source.
294 if (previous.
isSet())
296 auto inserted = newOps.insert(op);
298 assert(inserted.second &&
"expected newly created op");
301 void notifyOperationErased(
Operation *op)
override {
302 ForwardingListener::notifyOperationErased(op);
// BufferizeToAllocationOp::apply — installs a NewOpsListener (restored via
// scope_exit), translates the string memcpy/alloc attributes into option
// enums, bufferizes each payload op, and reports the allocated buffers and
// newly created ops as results. NOTE(review): fragmentary excerpt.
316 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
317 NewOpsListener newOpsListener(previousListener);
// Map the memcpy_op string attribute onto the bufferization options; the
// verifier (below) restricts it to these three spellings.
321 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
324 }
else if (getMemcpyOp() ==
"memref.copy") {
327 }
else if (getMemcpyOp() ==
"linalg.copy") {
331 llvm_unreachable(
"invalid memcpy op");
// Same for the alloc_op attribute (memref.alloc / memref.alloca).
333 if (getAllocOp() ==
"memref.alloc") {
336 }
else if (getAllocOp() ==
"memref.alloca") {
340 llvm_unreachable(
"invalid alloc op");
342 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
343 options.emitDealloc = getEmitDealloc();
// Optional memory space; empty Attribute means "unspecified".
347 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
349 for (
Operation *op : state.getPayloadOps(getTarget())) {
354 <<
"failed to bufferize operation";
355 diag.attachNote(op->
getLoc()) <<
"target payload op";
358 allocatedBuffers.push_back(buffer);
// Publish results: buffers as values, tracked new ops as an op handle.
362 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
363 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
// getEffects + verify for BufferizeToAllocationOp. The verifier rejects any
// memcpy_op / alloc_op string not in the supported sets, mirroring the
// llvm_unreachable branches in apply. NOTE(review): fragmentary excerpt.
367 void transform::BufferizeToAllocationOp::getEffects(
369 if (getBufferizeDestinationOnly()) {
381 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
382 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
383 return emitOpError() <<
"unsupported memcpy op";
384 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
385 return emitOpError() <<
"unsupported alloc op";
// DOWNSCALE helper macros: try to apply a DownscaleSizeOneWindowed2DConvolution
// pattern via tryApply; on success record the result and return success,
// otherwise fall through to the next candidate. (No comments are placed
// inside the backslash-continued macro body below.)
398 #define DOWNSCALE(trans) \
400 FailureOr<LinalgOp> res = tryApply<trans>(target); \
401 if (succeeded(res)) { \
402 results.push_back(*res); \
403 return DiagnosedSilenceableFailure::success(); \
407 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
408 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
// Macros are scoped to this function; if no pattern matched, emit the
// default silenceable failure.
421 #undef DOWNSCALE_NORMAL
422 #undef DOWNSCALE_CALL
424 return emitDefaultSilenceableFailure(target);
// applyToOne for the decompose-via-interface transform: requires the target
// to implement AggregatedOpInterface, decomposes it, and replaces it with
// the produced values. NOTE(review): fragmentary excerpt.
438 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
439 if (!decomposableOp) {
441 "payload is not a decomposable op"));
442 return emitDefaultSilenceableFailure(target);
445 FailureOr<SmallVector<Value>> maybeNewResults =
446 decomposableOp.decomposeOperation(rewriter);
447 if (failed(maybeNewResults))
448 return emitDefaultSilenceableFailure(target);
450 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
// Walk defining ops of the replacement values (presumably to report them
// as results — tail of the loop is not visible here).
451 for (
Value val : *maybeNewResults) {
452 Operation *definition = val.getDefiningOp();
// EliminateLinalgOpAnchoredEmptyTensorsOp: runs one-shot-analysis with
// allowReturnAllocsFromLoops, then eliminates tensor.empty ops anchored on
// Linalg ops for each payload target. NOTE(review): fragmentary excerpt.
463 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
470 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
474 options.allowReturnAllocsFromLoops =
true;
476 for (
Operation *target : state.getPayloadOps(getTarget())) {
480 <<
"failed to analyze op";
482 rewriter, target, state)))
484 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
// applyTilingToAll: shared driver for tile-and-fuse transforms. Applies
// `applyFn` to every payload op (which must implement TilingInterface),
// replaces fused producers where replacements exist, and publishes the
// tiled ops plus one loop handle per tiled dimension.
// NOTE(review): fragmentary excerpt.
495 template <
typename Range>
499 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
505 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
506 if (!tilingInterfaceOp)
507 return transformOp->
emitError(
"only TilingInterface ops are supported");
510 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
511 applyFn(tilingInterfaceOp);
512 if (failed(tiledResults))
// Replace each result of original/fused ops with its tiled replacement
// where the tiling recorded one; then drop dead originals.
517 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
518 for (
Operation *toReplace : opsToReplace) {
519 for (
OpResult res : toReplace->getResults())
520 if (
auto replacement = tiledResults->replacements.lookup(res))
522 if (toReplace->use_empty()) {
528 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
529 assert(tiledResults->loops.size() == numLoops &&
530 "Mismatched number of loops, tile and fuse transform should have "
532 for (
unsigned int i = 0; i < numLoops; ++i)
533 loopOps[i].push_back(tiledResults->loops[i]);
// Result 0 = tiled ops; results 1..numLoops = the generated loops.
536 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
537 for (
unsigned int i = 0; i < numLoops; ++i)
538 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
// FuseOp::apply — reads tile sizes/interchange from integer array attrs and
// drives applyTilingToAll with tileConsumerAndFuseProducersUsingSCF. The
// loop-result count excludes zero tile sizes (untiled dims).
// NOTE(review): fragmentary excerpt.
548 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
550 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
556 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
560 rewriter, getOperation(), state.getPayloadOps(getTarget()),
561 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
562 [&](TilingInterface tilingInterfaceOp)
563 -> FailureOr<scf::SCFTileAndFuseResult> {
564 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
// FuseOp::verify — interchange must be a permutation of [0, rank), and the
// number of declared loop results must match the nonzero tile sizes.
573 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
574 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
575 if (!std::is_permutation(sequence.begin(), sequence.end(),
576 permutation.begin(), permutation.end())) {
577 return emitOpError() <<
"expects interchange to be a permutation, found "
578 << getTileInterchange();
582 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
583 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
584 if (numExpectedLoops != getNumResults() - 1)
585 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
// FuseIntoContainingOp::build plus a helper that rebuilds an scf.forall
// with an extra output so users of the fused producer that are dominated by
// the containing op can be redirected to the new forall result.
// NOTE(review): fragmentary excerpt.
594 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
597 Value containingOp) {
600 result.
addTypes({resultType, resultType});
// Collect users of the producer result that the containing op dominates —
// only those can safely use the new forall result.
616 (domInfo.
dominates(containingOp, user))) {
617 dominatedUsers.insert(user);
620 if (dominatedUsers.empty())
624 auto forallOp = cast<scf::ForallOp>(containingOp);
630 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
635 newOuts.push_back(outputs[resultNumber]);
// Recreate the forall with the extended output list and steal its body.
638 auto newforallOp = rewriter.
create<scf::ForallOp>(
639 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
640 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
642 newforallOp.getRegion().takeBody(forallOp.getRegion());
647 newforallOp.getBody()->addArgument(newOuts.back().getType(),
648 newOuts.back().getLoc());
649 auto bbArgs = newforallOp.getBody()->getArguments();
// Add a parallel_insert_slice for the new output inside the terminator.
657 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
659 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
660 Operation *firstYieldOp = yieldingOps.front();
663 Value dst = newforallOp.getRegionIterArgs().back();
665 rewriter.
create<tensor::ParallelInsertSliceOp>(firstYieldOp->
getLoc(), src,
666 dst, offsets, sizes, strides);
// Rewire old results to the new forall; restrict the extra result to the
// dominated users collected above.
670 newforallOp->getResult(result.index()));
673 newforallOp->getResults().back(),
676 return dominatedUsers.contains(user);
// First fusion strategy: find a tensor.extract_slice use of the producer
// inside the containing op, tile the producer to that slice, and replace
// the slice with the tiled value. Returns the tiled ops and (possibly) a
// new containing op. NOTE(review): fragmentary excerpt.
687 static std::tuple<SmallVector<Operation *>,
Operation *>
690 LLVM_DEBUG(
DBGS() <<
"Try to fuse a direct extract use\n");
691 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
692 if (!tileableProducer) {
694 <<
"producer is not a TileableInterface: " << *producerOp;
// Look for an extract_slice user nested inside the containing op.
701 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
702 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
703 return sliceOp && containingOp->isProperAncestor(sliceOp);
707 if (it == tileableProducer->getUsers().end()) {
708 diag.attachNote(tileableProducer->getLoc())
709 <<
"could not find fusion opportunity for: " << *tileableProducer;
712 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
// Which producer result feeds the slice — tiling targets that result.
719 int64_t resultNumber =
720 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
721 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
726 FailureOr<TilingResult> tileAndFuseResult =
727 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
730 if (failed(tileAndFuseResult)) {
731 diag.attachNote(tileableProducer->getLoc())
732 <<
"failed to tile producer op: " << *tileableProducer;
737 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
738 LLVM_DEBUG(
DBGS() <<
"tiledProducer: " << *tiledOp <<
"\n");
// Rank-reduce if needed so the tiled value's type matches the slice's.
743 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
744 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
745 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
746 if (failed(maybeRankReduced)) {
748 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
749 << tileAndFuseResult->tiledValues[0]
750 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
753 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
// Rebuild the containing forall if dominated users need the new result.
757 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
758 resultNumber, offsets, sizes);
760 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
// Second fusion strategy: the producer feeds the containing scf.forall as
// an init operand; fuse by cloning the producer inside the loop, tiling the
// clone to the extract_slice of the corresponding block argument, and
// erasing the temporary clone. NOTE(review): fragmentary excerpt.
773 LLVM_DEBUG(
DBGS() <<
"Try to fuse an extract use through block argument\n");
775 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
776 if (!tileableProducer) {
778 <<
"producer is not a TileableInterface: " << *producerOp;
// Find the use of the producer whose owner is the containing forall.
783 scf::ForallOp forallOp;
784 auto itProducerUses =
785 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
786 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
790 if (!forallOp || forallOp != containingOp) {
791 diag.attachNote(tileableProducer->getLoc())
792 <<
"could not find a use by the containing op: " << *tileableProducer;
// Find an extract_slice user of the matching block argument.
807 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
808 return sliceOp && containingOp->isProperAncestor(sliceOp);
812 if (itBBArgUsers == bbArg.
getUsers().end()) {
814 <<
"could not find fusion opportunity for bbArg: " << bbArg;
817 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
825 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
826 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
831 rewriter, tileableProducer->getLoc(), tileableProducer,
832 destinationTensors))) {
833 diag.attachNote(tileableProducer->getLoc())
834 <<
"failed to get destination tensors for: " << *tileableProducer;
// Clone the producer with its destination remapped to the bbArg; the clone
// exists only to be tiled and is unconditionally erased on scope exit.
839 bvm.
map(destinationTensors[resultNumber], bbArg);
840 auto tileableProducerClone =
841 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
843 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
846 FailureOr<TilingResult> tileAndFuseResult =
847 tileableProducerClone.generateResultTileValue(
848 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
849 sliceOpToTile.getMixedSizes());
850 if (failed(tileAndFuseResult)) {
851 diag.attachNote(tileableProducer->getLoc())
852 <<
"failed to tile producer op: " << *tileableProducer;
857 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
858 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
859 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
860 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
861 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
// Repoint the forall init operand at the producer's destination tensor.
866 destinationTensors.front());
869 return tileAndFuseResult->tiledOps;
// Last-resort fusion strategy: clone the whole producer next to its first
// in-containing-op use and rewire that use to the clone's result. Refuses
// uses owned by the containing op itself or by parallel_insert_slice.
// NOTE(review): fragmentary excerpt.
875 LLVM_DEBUG(
DBGS() <<
"Try to fuse an use by cloning\n");
880 for (
OpOperand &use : result.getUses()) {
882 uses.push_back(&use);
// A direct use by the containing op (e.g. as an init) cannot be cloned.
887 if (containingOp == use.getOwner()) {
889 <<
"producer op use by containing op cannot be fused by cloning";
897 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
906 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
907 "Parallel insert slice is not a valid clone destination");
908 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
909 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
// Clone and redirect the chosen use in-place under the rewriter.
913 fusedOp = rewriter.
clone(*producerOp);
915 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
// FuseIntoContainingOp::apply — greedy worklist fusion: requires exactly
// one containing op; repeatedly picks the next producer that has uses
// inside the container and tries the three strategies in order (direct
// extract use, extract through block argument, clone). Publishes the fused
// ops and the (possibly replaced) containing op. NOTE(review): fragmentary.
920 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
930 auto producerOps = state.getPayloadOps(getProducerOp());
931 auto containingOps = state.getPayloadOps(getContainingOp());
932 if (!llvm::hasSingleElement(containingOps)) {
934 <<
"requires exactly one containing_op handle (got "
935 << llvm::range_size(containingOps) <<
")";
937 Operation *containingOp = *containingOps.begin();
// No producers: trivially succeed, still reporting the containing op.
940 if (std::empty(producerOps)) {
942 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
// Pick the next producer that is actually used inside the container; a
// producer is removed from the worklist once its last such use is fused.
950 auto getNextProducer = [&]() -> FailureOr<Operation *> {
951 for (
const auto &it :
enumerate(remainingProducers)) {
954 int64_t numUsesInContainingOp =
956 return containingOp->isAncestor(op);
961 if (numUsesInContainingOp > 0) {
962 if (numUsesInContainingOp == 1)
963 remainingProducers.erase(remainingProducers.begin() + it.index());
970 while (!remainingProducers.empty()) {
971 auto nextProducer = getNextProducer();
972 if (failed(nextProducer)) {
974 <<
"could not find next producer to fuse into container";
975 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
983 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
// Strategy 1: direct extract_slice use.
990 auto [tiledOps, newContainingOp] =
992 if (!tiledOps.empty()) {
993 LLVM_DEBUG(
DBGS() <<
"\nFused a direct extract use\n" << *containingOp);
994 fusedOps.append(tiledOps);
995 if (newContainingOp) {
// The containing op was rebuilt — update the transform state mapping
// before erasing the stale op.
1003 LogicalResult replacementStatus =
1006 (void)replacementStatus;
1007 assert(succeeded(replacementStatus) &&
1008 "unable to update transform state mapping");
1009 rewriter.
eraseOp(containingOp);
1010 containingOp = newContainingOp;
// Strategy 2: extract through the forall block argument.
1017 rewriter,
diag, producerOp, containingOp);
1018 if (!tiledContainingOpOperand.empty()) {
1019 LLVM_DEBUG(
DBGS() <<
"\nFused an extract use through block argument\n"
1021 fusedOps.append(tiledContainingOpOperand);
// Strategy 3: plain cloning.
1028 LLVM_DEBUG(
DBGS() <<
"\nFused an use by cloning\n" << *containingOp);
1029 fusedOps.push_back(cloned);
1035 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1036 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
// FuseIntoContainingOp::getEffects plus the applyToOne fragments for the
// generalize (named -> linalg.generic) and specialize (generic -> named)
// transforms; each short-circuits when the target is already in the desired
// form. NOTE(review): fragmentary excerpt.
1040 void transform::FuseIntoContainingOp::getEffects(
// Generalize: a GenericOp is already generalized — presumably returned
// as-is (branch body not visible here).
1058 if (isa<GenericOp>(target)) {
1064 if (succeeded(
generic)) {
1065 results.
push_back(generic->getOperation());
1068 return emitDefaultSilenceableFailure(target);
// Specialize: only GenericOp targets can be specialized to a named op.
1081 if (!isa<GenericOp>(target)) {
1086 FailureOr<LinalgOp> named =
1088 if (succeeded(named)) {
1089 results.
push_back(named->getOperation());
1092 return emitDefaultSilenceableFailure(target);
// InterchangeOp applyToOne + verify. applyToOne validates that the
// interchange vector length equals the target's loop count before
// interchanging; verify checks that the attribute is a permutation of
// [0, size). NOTE(review): fragmentary excerpt.
1106 if (interchangeVector.empty()) {
1111 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1112 if (interchangeVector.size() != numLoops) {
1113 return emitSilenceableError()
1114 << getIteratorInterchangeAttrName() <<
" has length ("
1115 << interchangeVector.size()
1116 <<
") different from the number of loops in the target operation ("
1119 FailureOr<GenericOp> res =
1122 interchangeVector.end()));
// verify: compare against the identity sequence to detect non-permutations.
1131 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1132 if (!std::is_permutation(sequence.begin(), sequence.end(),
1133 permutation.begin(), permutation.end())) {
1134 return emitOpError()
1135 <<
"expects iterator_interchange to be a permutation, found "
1136 << getIteratorInterchange();
// LowerPackOp / LowerUnPackOp applyToOne: decompose tensor.pack into
// pad + expand_shape + transpose, and tensor.unpack into empty + transpose
// + collapse_shape + extract_slice, reporting each created op as a result.
// NOTE(review): fragmentary excerpt.
1150 FailureOr<LowerPackResult> res =
lowerPack(rewriter, target);
1153 <<
"cannot lower to pad + expand + transpose";
1156 transformResults.
push_back(res->expandShapeOp);
1157 transformResults.
push_back(res->transposeOp);
// Unpack lowering; attaches the payload op location on failure.
1170 FailureOr<LowerUnPackOpResult> res =
lowerUnPack(rewriter, target);
1173 emitSilenceableError()
1174 <<
"cannot lower to transpose + collapse + extract";
1175 diag.attachNote(target->getLoc()) <<
"target payload op";
1178 transformResults.
push_back(res->emptyOp);
1179 transformResults.
push_back(res->transposeOp);
1180 transformResults.
push_back(res->collapseShapeOp);
1181 transformResults.
push_back(res->extractSliceOp);
// MatchOp::apply — walks the single payload op and collects ops matching
// the configured filters: op-name set, interface kind, attribute dictionary,
// result type, and operand types (either one type for all operands or an
// exact per-operand list). NOTE(review): fragmentary excerpt.
1211 if (getOps().has_value())
1212 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1213 getOps()->getAsValueRange<StringAttr>().end());
1215 auto payloadOps = state.getPayloadOps(getTarget());
1216 if (!llvm::hasSingleElement(payloadOps)) {
// Set when the operand-type filter list has the wrong arity; the error is
// emitted after the walk completes (a walk cannot return a diagnostic).
1221 bool incorrectNumOperandTypes =
false;
1228 if (getInterface().has_value()) {
1229 auto iface = getInterface().value();
1230 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1233 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1234 !isa<TilingInterface>(op))
1236 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1237 !isa<LoopLikeOpInterface>(op))
// Attribute filter: every listed attribute (other than the op's own
// configuration attrs) must be present with an equal value.
1242 if (getOpAttrs().has_value()) {
1243 DictionaryAttr opAttrs = getOpAttrs().value();
1245 if (attr.getName() == getInterfaceAttrName() ||
1246 attr.getName() == getOpsAttrName())
1248 if (!op->
hasAttr(attr.getName()))
1250 if (op->
getAttr(attr.getName()) != attr.getValue())
1255 if (getFilterResultType().has_value()) {
1256 Type t = getFilterResultType().value();
// Operand-type filter: a single entry applies to all operands; otherwise
// the list must match the operand count element-wise.
1261 if (getFilterOperandTypes().has_value()) {
1262 mlir::ArrayAttr types = getFilterOperandTypes().value();
1265 if (types.size() == 1) {
1268 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1269 Type t = cast<::mlir::Type>(typeattr.getValue());
1271 [&](
Type operandType) { return operandType == t; }))
1276 if (types.size() != operandTypes.size()) {
1277 incorrectNumOperandTypes =
true;
1281 for (
auto [attr, operandType] :
1282 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1283 auto typeattr = cast<mlir::TypeAttr>(attr);
1284 Type type = cast<::mlir::Type>(typeattr.getValue());
1286 if (type != operandType)
1297 (*payloadOps.begin())->
walk(matchFun);
1298 if (incorrectNumOperandTypes)
1300 "type, then it must contain as much types as "
1301 "the number of operands in the target ops");
1302 results.
set(cast<OpResult>(getResult()), res);
// Custom parser for MultiTileSizesOp's type signature: expects a trailing
// functional type with exactly one input and one result; the single result
// type is shared by low_size, high_size, and split_point.
// NOTE(review): fragmentary excerpt.
1317 Type &targetType,
Type &lowSizeType,
1319 Type &splitPointType) {
1320 FunctionType funcType;
1322 if (failed(parser.
parseType<FunctionType>(funcType)))
1325 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1326 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1327 "argument and one result";
// Fan the single result type out to all three size/point types.
1329 targetType = funcType.getInput(0);
1330 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
// MultiTileSizesOp applyToOne / getEffects / verify. When low_size is
// param-typed, sizes are computed statically (requires a static shape) and
// returned as attributes; otherwise they are emitted as payload ops.
// NOTE(review): fragmentary excerpt.
1338 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1339 if (target.hasDynamicShape()) {
1340 auto diag = emitSilenceableError()
1341 <<
"cannot compute parametric tile sizes for dynamically "
1342 "shaped payload op";
1343 diag.attachNote(target->getLoc()) <<
"payload op";
1348 target, getDimension(), getTargetSize(), getDivisor());
1350 return emitSilenceableError()
1351 <<
"failed to compute multi-size tiling sizes";
// Static path: package the computed sizes as IntegerAttr params.
1354 Builder builder(target.getContext());
1355 results.
assign(llvm::map_range(
1357 spec->lowTileSize * spec->lowTripCount}),
1358 [&builder,
this](int64_t value) {
// Dynamic path: emit size-computation IR next to the target.
1370 builder, target, getDimension(), targetSize, divisor);
1372 return emitSilenceableError() <<
"could not generate tile size computation";
1379 {spec->lowTileSize, spec->lowTripCount});
1380 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1381 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1382 assert(lowTileSize && highTileSize && splitPoint &&
1383 "tile sizes are not produced by operations");
1391 void transform::MultiTileSizesOp::getEffects(
1395 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
// verify: all three results must share one type.
1404 return emitOpError() <<
"expects all results type to be the same";
// PackOp build / getMixedPackedSizes / apply. apply requires exactly one
// LinalgOp target and one packed size per loop, resolves the mixed
// static/dynamic sizes, and packs the op. NOTE(review): fragmentary.
1424 builder.
getContext(), GenericOp::getOperationName());
1425 build(builder, result,
// Combine static and dynamic packed sizes into OpFoldResults.
1434 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1441 auto targetOps = state.getPayloadOps(getTarget());
// Empty target handle: succeed with an empty packed_op result.
1443 if (std::empty(targetOps)) {
1444 transformResults.
set(cast<OpResult>(getPackedOp()),
1449 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1450 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1451 return emitSilenceableError()
1452 <<
"requires target to map to exactly 1 LinalgOp (got "
1453 << llvm::range_size(targetOps) <<
")";
1456 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1457 return emitSilenceableError()
1458 <<
"requires number of packed sizes match the number of loops ("
1459 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1466 state, *
this, packedSizes, getMixedPackedSizes());
1469 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1470 if (failed(maybeResult))
1473 transformResults.
set(cast<OpResult>(getPackedOp()),
1474 {maybeResult->packedLinalgOp.getOperation()});
// PackOp::getEffects plus PackGreedilyOp verify/apply. verify enforces that
// matmul_inner_dims_order is a permutation and that packed_size and
// padded_sizes_next_multiple_of are mutually exclusive per dimension; apply
// packs each LinalgOp payload with the matmul strategy, passing the
// original op through on failure. NOTE(review): fragmentary excerpt.
1478 void transform::PackOp::getEffects(
1492 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1493 <<
" is not a valid permutation";
1496 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1497 for (
auto [s, nmo] :
1498 llvm::zip_equal(getMixedMatmulPackedSizes(),
1499 getMatmulPaddedSizesNextMultipleOf())) {
// Reject a dimension that sets both a nonzero packed size and a nonzero
// next-multiple-of padding.
1502 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1503 return emitOpError() <<
"at most one of the packed_size and the "
1504 "padded_sizes_next_multiple_of can be nonzero "
1505 "for the matmul strategy";
1517 for (
Operation *op : state.getPayloadOps(getTarget())) {
1518 auto linalgOp = dyn_cast<LinalgOp>(op);
1529 getMixedMatmulPackedSizes(),
1531 getMatmulPaddedSizesNextMultipleOf(),
1532 getMatmulInnerDimsOrder());
1533 if (succeeded(packResult)) {
1534 results.push_back(packResult->packedLinalgOp);
// Best-effort: an op that could not be packed is forwarded unchanged.
1537 results.push_back(linalgOp);
1539 transformResults.
set(cast<OpResult>(getPackedOp()), results);
// PackGreedilyOp accessors/effects, PackTransposeOp::verify, and a helper
// that validates a permutation against a pack/unpack op: inner permutations
// are checked against the inner-dims rank, outer permutations against the
// source rank (pack) or dest rank (unpack). NOTE(review): fragmentary.
1545 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1549 void transform::PackGreedilyOp::getEffects(
1563 return emitOpError() << getInnerPermAttrName()
1564 <<
" is not a valid permutation";
1567 return emitOpError() << getOuterPermAttrName()
1568 <<
" is not a valid permutation";
// At least one of the two permutations must be provided for the transpose
// to do anything.
1570 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1571 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1572 <<
" or " << getOuterPermAttrName()
1573 <<
" must be specified";
1579 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1589 template <
typename RelayoutOpTy>
1592 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1594 llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1595 "applies to only pack or unpack operations");
// A null op or empty permutation is trivially acceptable.
1596 if (!op || permutation.empty())
1598 size_t innerRank = op.getInnerDimsPos().size();
1599 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1603 if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1604 return permutation.size() == op.getSourceRank() &&
1607 return permutation.size() == op.getDestRank() &&
// PackTransposeOp::apply — requires exactly one pack/unpack op and one
// LinalgOp that are directly connected (single use / direct producer),
// validates the requested outer/inner permutations, then applies
// packTranspose and publishes the transposed pack, packed linalg op, and
// (optional) unpack op. NOTE(review): fragmentary excerpt.
1615 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1616 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
// Empty handle: succeed with all three results empty.
1618 if (std::empty(packOrUnpackOps)) {
1619 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1620 transformResults.
set(cast<OpResult>(getPackOp()), {});
1621 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1627 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1628 !llvm::hasSingleElement(linalgOps)) {
1629 return emitSilenceableError()
1630 <<
"requires target to map to exactly 1 "
1631 "packing op and 1 packed op ("
1632 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1633 << llvm::range_size(linalgOps) <<
")";
// The packing handle must be either a tensor.pack or tensor.unpack.
1637 auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1638 auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1639 if ((!packOp && !unPackOp)) {
1640 return emitSilenceableError() <<
"requires target to map to a "
1641 "tensor.pack or tensor.unpack";
1643 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1644 if (!linalgOpTarget)
1645 return emitSilenceableError() <<
"requires a LinalgOp target";
// Resolve the LinalgOp adjacent to the pack/unpack and check it is the one
// the user targeted.
1649 if (packOp && packOp.getResult().hasOneUse())
1650 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1652 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1653 if (linalgOp != linalgOpTarget) {
1655 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1656 : StringLiteral{
"not produced by the LinalgOp target"};
1657 return emitSilenceableError() << errorMsg;
// Unpack case: derive the matching pack op from the consumed init operand.
1663 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1664 OpOperand *packUse = linalgOp.getDpsInitOperand(
1665 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1667 if (!packOp || !packOp.getResult().hasOneUse())
1668 return emitSilenceableError() <<
"could not find matching pack op";
// Validate both permutations against the relevant op and rank.
1672 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1674 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1675 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1676 ? StringLiteral{
"invalid outer_perm"}
1677 : StringLiteral{
"invalid inner_perm"};
1681 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1682 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1688 assert(packOp && linalgOp &&
"unexpected null op");
1692 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1694 assert(succeeded(res) &&
"unexpected packTranspose failure");
// Publish results; unpack_op is empty when no unpack was involved.
1697 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1698 transformResults.
set(cast<OpResult>(getPackedOp()),
1699 {res->transposedLinalgOp});
1701 transformResults.
set(cast<OpResult>(getUnPackOp()),
1702 {res->transposedUnPackOp});
1704 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
// PadOp builder overloads (static-only and mixed pad_to_multiple_of),
// getEffects, and the mixed-values accessor. An empty pad_to_multiple_of
// list is encoded as a null attribute. NOTE(review): fragmentary excerpt.
1719 StringRef copyBackOp) {
1729 (padToMultipleOf.empty()
1731 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1732 b.getI64ArrayAttr(packPaddings),
1733 b.getArrayAttr(transposePaddings),
1734 b.getStringAttr(copyBackOp));
// Mixed (static + dynamic) pad_to_multiple_of overload.
1742 StringRef copyBackOp) {
1747 staticPadToMultipleOf);
1754 dynamicPadToMultipleOf,
1755 staticPadToMultipleOf,
1761 void PadOp::getEffects(
1771 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
// PadOp::apply — for each LinalgOp payload: converts pack_paddings to
// bools, parses/validates padding values against operand element types,
// reads the transpose/padding-dimension attributes, resolves the copy-back
// op kind, pads the op, and publishes padded ops, pad ops, and copy-back
// ops as results. NOTE(review): fragmentary excerpt.
1778 auto transformOp = cast<TransformOpInterface>(getOperation());
1781 for (
Operation *target : state.getPayloadOps(getTarget())) {
1782 auto linalgTarget = dyn_cast<LinalgOp>(target);
1783 if (!linalgTarget) {
1784 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1785 diag.attachNote(target->
getLoc()) <<
"target op";
// pack_paddings is stored as i64s; nonzero means "pack this operand".
1791 for (int64_t packPadding :
1792 extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
1793 packPaddings.push_back(
static_cast<bool>(packPadding));
// Pair each padding value with the corresponding operand type and check
// that its type matches the operand element type; string attributes are
// parsed into typed attributes first.
1797 for (
auto const &it :
1798 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1799 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1801 emitOpError(
"expects padding values to be typed attributes");
1806 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
1810 if (!parsedAttr || parsedAttr.getType() != elementType) {
1811 auto diag = this->emitOpError(
"expects a padding that parses to ")
1812 << elementType <<
", got " << std::get<0>(it);
1813 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1816 paddingValues.push_back(parsedAttr);
1820 if (attr.getType() != elementType) {
1821 auto diag = this->emitOpError(
"expects a padding value of type ")
1822 << elementType <<
", got " << attr;
1823 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1826 paddingValues.push_back(attr);
// Each transpose_paddings entry is itself an integer array attribute.
1831 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1832 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1833 cast<ArrayAttr>(transposeVector)));
1838 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1842 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1845 if (padToMultipleOf.empty())
1849 options.padToMultipleOf = padToMultipleOf;
1850 options.paddingValues = paddingValues;
1851 options.packPaddings = packPaddings;
// copy_back is restricted by the op definition to these spellings.
1852 if (getCopyBackOp() ==
1853 bufferization::MaterializeInDestinationOp::getOperationName()) {
1856 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1858 }
else if (getCopyBackOp() == kCopyOpNone) {
1861 llvm_unreachable(
"unsupported copy_back op");
1867 replacements, newPadOps))) {
1868 auto diag = emitSilenceableError() <<
"failed to pad op";
1869 diag.attachNote(target->
getLoc()) <<
"target op";
1878 rewriter.
replaceOp(linalgTarget, replacements);
1879 paddedOps.push_back(paddedOp);
1880 padOps.append(newPadOps.begin(), newPadOps.end());
// Deduplicate copy-back ops: several replacements may share one.
1882 for (
Value v : replacements) {
1883 Operation *copyBackOp = v.getDefiningOp();
1884 if (!llvm::is_contained(copyBackOps, copyBackOp))
1885 copyBackOps.push_back(copyBackOp);
1890 results.
set(cast<OpResult>(getPadded()), paddedOps);
1891 results.
set(cast<OpResult>(getPad()), padOps);
1892 results.
set(cast<OpResult>(getCopy()), copyBackOps);
1898 extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
1899 if (any_of(packPaddings, [](int64_t packPadding) {
1900 return packPadding != 0 && packPadding != 1;
1902 return emitOpError()
1903 <<
"expects pack_paddings to contain booleans (0/1), found "
1904 << getPackPaddings();
1908 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1909 if (any_of(paddingDimensions,
1910 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
1911 return emitOpError() <<
"expects padding_dimensions to contain positive "
1913 << getPaddingDimensions();
1915 if (!getMixedPadToMultipleOf().empty()) {
1916 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
1917 return emitOpError() <<
"expects as many multiples as padding_dimensions";
1920 ArrayAttr transposes = getTransposePaddings();
1923 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
1924 if (!std::is_permutation(sequence.begin(), sequence.end(),
1926 return emitOpError()
1927 <<
"expects transpose_paddings to be a permutation, found "
1931 if (getCopyBackOp() !=
1932 bufferization::MaterializeInDestinationOp::getOperationName() &&
1933 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1934 getCopyBackOp() != kCopyOpNone)
1935 return emitOpError() <<
"invalid copy_back_op";
1947 auto targetOps = state.getPayloadOps(getTarget());
1948 auto loopOps = state.getPayloadOps(getLoop());
1949 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1951 <<
"requires exactly one target and one loop handle (got "
1952 << llvm::range_size(targetOps) <<
" and "
1953 << llvm::range_size(loopOps) <<
")";
1956 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1957 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1958 if (!padOp || !loopOp)
1961 FailureOr<linalg::detail::PackingResult> result =
1967 if (result->clonedLoopIvs.empty()) {
1968 transformResults.
set(cast<OpResult>(getPackingLoop()),
1969 {result->hoistedPadOp.getOperation()});
1972 auto outerPackedLoop =
1974 transformResults.
set(cast<OpResult>(getPackingLoop()),
1975 {outerPackedLoop.getOperation()});
1981 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
1982 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
1984 return emitOpError() <<
"expects transpose to be a permutation, found "
1990 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2000 tensor::PadOp target,
2003 tensor::PadOp hoistedPadOp;
2005 FailureOr<Value> result =
2007 hoistedPadOp, transposeOps);
2008 if (succeeded(result)) {
2018 return emitDefaultSilenceableFailure(target);
2023 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2024 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2026 return emitOpError() <<
"expects transpose to be a permutation, found "
2042 if (!getOperandsToPromote().empty())
2044 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2045 if (getUseFullTilesByDefault())
2047 getUseFullTilesByDefault());
2049 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2050 if (!getUseFullTileBuffers().empty())
2052 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2053 if (getAlignment().has_value())
2054 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2055 if (getMemorySpace().has_value())
2056 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2058 if (getMapping().has_value()) {
2060 auto mapping = *getMapping();
2061 if (mapping.size() > 1)
2062 return emitDefaultDefiniteFailure(target);
2064 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2066 if (addressSpace.getAddressSpace() ==
2067 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2074 }
else if (addressSpace.getAddressSpace() ==
2075 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2083 return emitDefaultDefiniteFailure(target);
2088 return emitDefaultDefiniteFailure(target);
2091 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2093 return emitDefaultDefiniteFailure(target);
2106 auto payload = state.getPayloadOps(getTarget());
2110 if (target->getNumOperands() > 0)
2113 target->getNumRegions() > 0)
2115 <<
"expected target that is isolated from above";
2119 Operation *pattern = &getBodyRegion().front().front();
2122 if (getOperation()->isAncestor(target))
2127 replacements.push_back(replacement);
2129 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2133 void transform::ReplaceOp::getEffects(
2141 if (!getBodyRegion().hasOneBlock())
2142 return emitOpError() <<
"expected one block";
2143 if (std::distance(getBodyRegion().front().begin(),
2144 getBodyRegion().front().end()) != 1)
2145 return emitOpError() <<
"expected one operation in block";
2146 Operation *replacement = &getBodyRegion().front().front();
2149 <<
"expected replacement without operands";
2153 <<
"expect op that is isolated from above";
2171 target.createFlatListOfOperandDims(b, loc);
2172 AffineMap map = target.getShapesToLoopsMap();
2188 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2189 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2190 if (failed(maybeTilingResult))
2191 return emitDefaultDefiniteFailure(target);
2193 if (target->getNumResults())
2194 rewriter.
replaceOp(target, maybeTilingResult->replacements);
2198 results.
reserve(maybeTilingResult->tiledOps.size());
2199 for (
Operation *tiled : maybeTilingResult->tiledOps)
2213 for (
Operation *target : state.getPayloadOps(getTarget())) {
2214 auto tilingOp = dyn_cast<TilingInterface>(*target);
2217 emitSilenceableError()
2218 <<
"expected the payload to implement TilingInterface";
2219 diag.attachNote(target->getLoc()) <<
"payload op";
2223 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2225 if (failed(generatedLoops))
2226 return emitDefaultDefiniteFailure(target);
2227 for (scf::ForOp &loop : *generatedLoops) {
2228 loops.push_back(loop.getOperation());
2232 results.
set(cast<OpResult>(getResult()), loops);
2241 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2247 FailureOr<Operation *> maybeResult =
2249 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2250 [&rewriter](
auto op) {
2253 if (failed(maybeResult))
2254 return emitDefaultSilenceableFailure(target);
2268 llvm::to_vector(state.getPayloadOps(getTarget()));
2270 bool isMultiwaySplit = getMultiway();
2272 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2274 <<
"requires exactly one target when "
2275 "multiway split is enabled (got "
2276 << llvm::range_size(payload) <<
")";
2281 if (!isMultiwaySplit)
2282 chunkSizes.reserve(payload.size());
2284 if (getDynamicChunkSizes()) {
2286 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2287 chunkSizes = llvm::to_vector(llvm::map_range(
2288 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2291 diag = emitSilenceableError()
2292 <<
"expected dynamic split point handle to point to a "
2293 "single-result index-typed op";
2294 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2299 chunkSizes = llvm::to_vector(
2300 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2303 if (
diag.isSilenceableFailure())
2308 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2310 <<
"expected the dynamic split point handle to point to as "
2312 << chunkSizes.size() <<
") as the target handle ("
2313 << payload.size() <<
")";
2316 chunkSizes.resize(payload.size(),
2320 auto checkStructuredOpAndDimensions =
2323 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2324 diag.attachNote(loc) <<
"target op";
2328 if (getDimension() >= linalgOp.getNumLoops()) {
2329 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2330 <<
" does not exist in target op";
2331 diag.attachNote(loc) <<
"target op";
2337 auto checkFailureInSplitting =
2341 diag.attachNote(loc) <<
"target op";
2347 if (isMultiwaySplit) {
2351 TilingInterface head, tail;
2354 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2358 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2359 if (
diag.isSilenceableFailure())
2365 target = tail.getOperation();
2370 linalgOp = cast<LinalgOp>(target);
2375 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2376 getDimension(), chunkSize);
2380 checkFailureInSplitting(!head && !tail, loc);
2381 if (
diag.isDefiniteFailure())
2384 opList.push_back(head.getOperation());
2389 opList.push_back(tail.getOperation());
2390 results.
set(cast<OpResult>(getFirst()), opList);
2391 results.
set(cast<OpResult>(getSecond()), {});
2397 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2400 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2402 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2404 if (
diag.isSilenceableFailure())
2408 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2409 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2410 getDimension(), std::get<1>(pair));
2414 checkFailureInSplitting(!first.back() && !second.back(), loc);
2419 if (!second.back()) {
2420 noSecondPart = target;
2425 if (second.size() != first.size() && !second.empty()) {
2426 auto diag = emitSilenceableError()
2427 <<
"splitting does not produce the second part for a subset "
2430 <<
"expected splitting to produce the second part of all "
2431 "or none of the targets";
2433 <<
"first target with no second part";
2437 results.
set(cast<OpResult>(getFirst()), first);
2438 results.
set(cast<OpResult>(getSecond()), second);
2443 void SplitOp::getEffects(
2446 if (getDynamicChunkSizes())
2454 IntegerAttr staticChunkSizes;
2460 if (!dynamicPointParseResult.
has_value()) {
2461 int64_t staticChunkSizesValue;
2462 if (failed(parser.
parseInteger(staticChunkSizesValue)))
2475 if (dynamicPointParseResult.
has_value()) {
2476 Type ChunkSizesType;
2477 if (failed(*dynamicPointParseResult) || parser.
parseComma() ||
2489 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2491 result.
addTypes({targetType, targetType});
2496 printer <<
" " << getTarget() <<
" after ";
2497 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2498 if (staticChunkSize != ShapedType::kDynamic)
2499 printer << staticChunkSize;
2501 printer << getDynamicChunkSizes();
2504 {getStaticChunkSizesAttrName()});
2505 printer <<
" : " << getTarget().getType();
2506 if (staticChunkSize == ShapedType::kDynamic)
2507 printer <<
", " << getDynamicChunkSizes().getType();
2511 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2512 (getDynamicChunkSizes() ==
nullptr)) {
2513 return emitOpError() <<
"expects either a dynamic or a static split "
2514 "point to be provided";
2523 void transform::SplitReductionOp::build(
2525 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2526 bool useScalingAlgorithm,
bool useAlloc) {
2532 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2534 if (innerParallel) {
2535 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2538 if (useScalingAlgorithm) {
2540 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2548 result.
addTypes({resultType, resultType, resultType, resultType});
2557 unsigned(getInsertSplitDimension()),
2558 bool(getInnerParallel())};
2561 FailureOr<SplitReductionResult> splitResult =
2562 (getUseScalingAlgorithm())
2565 if (failed(splitResult))
2566 return emitDefaultDefiniteFailure(target);
2568 results.
push_back(splitResult->initOrAlloc);
2570 results.
push_back(splitResult->splitLinalgOp);
2571 results.
push_back(splitResult->resultCombiningLinalgOp);
2579 void transform::TileReductionUsingForOp::build(
2590 build(builder, result,
2593 staticTileSizesAttr);
2602 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2606 return emitDefaultSilenceableFailure(target);
2607 for (
Value initValue : result->initialValues)
2609 for (
auto parallelTiledOp : result->parallelTiledOps)
2611 for (
auto mergeOp : result->mergeOps)
2613 results.
push_back(result->loops.front());
2621 void transform::TileReductionUsingForallOp::build(
2624 ArrayAttr mapping) {
2634 build(builder, result,
2637 staticNumThreadsAttr,
2638 staticTileSizesAttr,
2651 FailureOr<linalg::ForallReductionTilingResult> result =
2653 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2654 numThreads, tileSizes, getMapping());
2656 if (failed(result)) {
2657 auto diag = emitSilenceableError() <<
"could not tile reduction";
2658 diag.attachNote(target.getLoc()) <<
"target operation";
2661 for (
Value initValue : result->initialValues)
2663 for (
auto parallelTiledOp : result->parallelTiledOps)
2665 for (
auto mergeOp : result->mergeOps)
2681 llvm::to_vector(state.getPayloadOps(getTarget()));
2683 if (!llvm::hasSingleElement(targetOps)) {
2685 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
2690 auto linalgOp = dyn_cast<LinalgOp>(target);
2691 auto tileableOp = dyn_cast<TilingInterface>(target);
2696 OpBuilder builder(linalgOp.getContext());
2698 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
2699 if (linalgOp.hasDynamicShape()) {
2700 auto diag = emitSilenceableError()
2701 <<
"cannot compute parametric tile sizes for dynamically "
2702 "shaped payload op";
2703 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
2707 FailureOr<StaticContinuousTileSizeSpecification> spec =
2711 return emitSilenceableError()
2712 <<
"failed to compute multi-size tiling sizes";
2717 for (
auto &&[tileSize, tripCount] :
2718 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2719 chunkSizes.push_back(tileSize * tripCount);
2722 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
2726 transformResults.
setParams(cast<OpResult>(getTileSizes()),
2727 getI64AttrsFromI64(spec->tileSizes));
2728 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
2729 getI64AttrsFromI64(chunkSizes));
2737 unsigned dimension = getDimension();
2740 builder, tileableOp, dimension, targetSize,
true);
2742 return emitSilenceableError() <<
"could not generate tile size computation";
2754 for (
auto &&[tileSize, tripCount] :
2755 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2756 splitPoint = apply(s0 * s1, {tileSize, tripCount});
2757 chunkSizes.push_back(splitPoint);
2761 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
2766 transformResults.
set(cast<OpResult>(getTileSizes()),
2767 getDefiningOps(spec->tileSizes));
2768 transformResults.
set(cast<OpResult>(getChunkSizes()),
2769 getDefiningOps(chunkSizes));
2777 return emitOpError() <<
"expects all results type to be the same";
2783 void transform::ContinuousTileSizesOp::getEffects(
2785 if (isa<TransformParamTypeInterface>(getTileSizes().
getType()))
2801 Type &tileSizesType,
2802 Type &chunkSizesType) {
2803 FunctionType funcType;
2805 if (failed(parser.
parseType<FunctionType>(funcType)))
2808 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2809 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
2810 "argument and one result";
2812 targetType = funcType.getInput(0);
2813 tileSizesType = chunkSizesType = funcType.getResult(0);
2822 void transform::TileUsingForOp::build(
2827 return build(builder, result, loopTypes,
2831 interchange, scalableSizes);
2834 void transform::TileUsingForOp::build(
2838 build(builder, result, target,
2840 interchange, scalableSizes);
2843 void transform::TileUsingForOp::build(
2850 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2854 void transform::TileUsingForOp::build(
2866 unsigned numExpectedLoops =
2867 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2869 resultTypes.reserve(numExpectedLoops);
2870 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2871 "expected one loop type or as many as loops");
2872 if (loopTypes.size() == 1)
2873 resultTypes.append(numExpectedLoops, loopTypes[0]);
2875 llvm::append_range(resultTypes, loopTypes);
2877 if (scalableSizes.has_value())
2878 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2879 build(builder, result, target.
getType(),
2883 staticTileSizesAttr,
2885 expandedScalableSizes);
2890 return emitOpError(
"expected same number of sizes (")
2892 << getScalableSizes().size() <<
")";
2894 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
2895 if (getLoops().size() != numExpectedLoops)
2896 return emitOpError(
"expected number of loops to tile (")
2897 << numExpectedLoops <<
") to match number of `loops` results ("
2898 << getLoops().size() <<
")";
2909 llvm::to_vector(state.getPayloadOps(getTarget()));
2915 if (isa<ParamType>(transformValue.getType())) {
2916 dynamicSizeProducers.push_back({});
2918 paramSizes.push_back(
2919 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
2920 return cast<IntegerAttr>(attr).getValue().getSExtValue();
2923 if (paramSizes.back().size() != targets.size()) {
2925 emitSilenceableError()
2926 <<
"expected as many parameter values ("
2927 << dynamicSizeProducers.back().size() <<
") as target ops ("
2928 << targets.size() <<
")";
2929 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
2935 paramSizes.push_back({});
2936 dynamicSizeProducers.push_back(
2937 llvm::to_vector(state.getPayloadOps(transformValue)));
2939 if (dynamicSizeProducers.back().size() != targets.size()) {
2941 emitSilenceableError()
2942 <<
"expected as many dynamic size-producing operations ("
2943 << dynamicSizeProducers.back().size() <<
") as target ops ("
2944 << targets.size() <<
")";
2945 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
2949 for (
Operation *op : dynamicSizeProducers.back()) {
2956 emitSilenceableError() <<
"expected sizes to be produced by ops "
2957 "with a single index-type result";
2958 diag.attachNote(op->
getLoc()) <<
"size producer op";
2959 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
2966 loops.resize(getLoops().size());
2967 auto scalableSizes = getScalableSizes();
2969 auto tilingInterface = dyn_cast<TilingInterface>(op);
2970 if (!tilingInterface) {
2972 emitSilenceableError()
2973 <<
"only ops implementing TilingInterface are supported";
2974 diag.attachNote(op->
getLoc()) <<
"target op";
2977 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
2979 emitSilenceableError()
2980 <<
"too many tiles provided, expected at most "
2981 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
2982 << tileSizes.size();
2983 diag.attachNote(op->
getLoc()) <<
"target op";
2988 if (tileSizes.empty()) {
2997 sizes.reserve(tileSizes.size());
2998 unsigned dynamicIdx = 0;
3001 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3002 if (scalableSizes[ofrIdx]) {
3003 auto val = b.
create<arith::ConstantIndexOp>(
3004 getLoc(), cast<IntegerAttr>(attr).getInt());
3010 sizes.push_back(attr);
3017 assert((dynamicSizes.empty() ^ params.empty()) &&
3018 "expected either dynamic sizes or parameters");
3019 if (!params.empty()) {
3022 sizes.push_back(dynamicSizes[index]->getResult(0));
3030 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3031 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3032 if (failed(maybeTilingResult))
3035 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3037 tiled.append(maybeTilingResult->tiledOps);
3039 loops[en2.index()].push_back(en2.value());
3042 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3044 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3053 results.reserve(tileSizes.size());
3054 unsigned dynamicPos = 0;
3056 for (int64_t size : tileSizes) {
3057 if (size == ShapedType::kDynamic) {
3058 results.push_back(dynamic[dynamicPos++]);
3066 void transform::TileUsingForOp::getEffects(
3078 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3082 ArrayAttr mapping) {
3083 return build(builder, result,
3091 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3095 ArrayAttr mapping) {
3105 build(builder, result,
3106 TypeRange{operationType, operationType},
3113 staticTileSizesAttr,
3117 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3121 ArrayAttr mapping) {
3122 return build(builder, result, target,
3127 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3131 ArrayAttr mapping) {
3142 build(builder, result,
3143 TypeRange{operationType, operationType},
3149 staticNumThreadsAttr,
3156 TransformOpInterface transformOp,
Operation *target,
3161 auto tileableOp = dyn_cast<TilingInterface>(target);
3164 transformOp.emitSilenceableError()
3165 <<
"only TilingInterface ops are supported";
3166 diag.attachNote(target->
getLoc()) <<
"target op";
3170 FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
3171 if (!mixedNumThreads.empty()) {
3176 rewriter, tileableOp, mixedTileSizes, mapping);
3179 if (failed(maybeTilingResult))
3180 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3181 rewriter.
replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
3183 tilingResult = *maybeTilingResult;
3191 auto transformOp = cast<TransformOpInterface>(getOperation());
3200 getPackedNumThreads()
3202 state, transformOp, mixedNumThreads, getPackedNumThreads())
3204 state, transformOp, mixedNumThreads, getMixedNumThreads());
3208 status = getPackedTileSizes()
3210 state, transformOp, mixedTileSizes, getPackedTileSizes())
3212 state, transformOp, mixedTileSizes, getMixedTileSizes());
3216 for (
Operation *target : state.getPayloadOps(getTarget())) {
3219 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3220 getMapping(), tilingResult);
3221 if (!
diag.succeeded())
3223 tileOps.push_back(tilingResult.
tileOp);
3224 tiledOps.push_back(tilingResult.
tiledOp);
3227 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3228 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3233 void transform::TileUsingForallOp::getEffects(
3246 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3255 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3256 static_cast<int>(getPackedNumThreads() !=
Value());
3257 if (numThreadsSpec > 1)
3259 "num_threads and packed_num_threads are mutually exclusive");
3260 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3261 static_cast<int>(getPackedTileSizes() !=
Value());
3262 if (tileSizesSpec > 1)
3264 "tile_sizes and packed_tile_sizes are mutually exclusive");
3265 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3266 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3267 "must be specified");
3275 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3277 bool vectorizePadding,
bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3279 if (vectorizePadding) {
3281 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3285 if (vectorizeExtract) {
3287 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3291 if (flatten1DDepthwiseConv) {
3293 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3304 explicit VectorizationPattern(
MLIRContext *context,
3305 bool vectorizeExtract =
false,
3306 bool flattenConv =
false)
3308 vectorizeNDExtract(vectorizeExtract),
3309 flatten1DDepthwiseConv(flattenConv) {}
3310 LogicalResult matchAndRewrite(
Operation *op,
3312 LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
3315 return vectorize(rewriter, linalgOp, {},
3316 {}, vectorizeNDExtract,
3317 flatten1DDepthwiseConv);
3323 bool vectorizeNDExtract =
false;
3327 bool flatten1DDepthwiseConv =
false;
3332 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3337 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3338 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3344 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3345 getFlatten_1dDepthwiseConv());
3347 if (!getDisableTransferPermutationMapLoweringPatterns())
3350 if (!getDisableMultiReductionToContractPatterns())
3358 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3359 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3364 if (getVectorizePadding())
3371 return emitDefaultDefiniteFailure(target);
3385 auto targets = state.getPayloadOps(getTarget());
3386 if (std::empty(targets))
3388 auto transformOp = cast<TransformOpInterface>(getOperation());
3391 state, transformOp, getMixedVectorSizes(), vectorSizes);
3397 if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3400 <<
"Unsupported Op, cannot vectorize";
3405 getVectorizeNdExtract().has_value()
3406 ? getVectorizeNdExtract().value()
3409 <<
"Attempted to vectorize, but failed";
3416 void transform::VectorizeOp::getEffects(
3425 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3429 if (getStaticVectorSizes().size() != getScalableSizes().size())
3430 return emitOpError(
"expected same number of vector sizes (")
3431 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3432 << getScalableSizes().size() <<
")";
3441 transform::HoistRedundantVectorTransfersOp::applyToOne(
3458 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3477 auto maybeTransformed =
3480 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3483 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3486 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3489 .Case([&](linalg::Conv2DNchwFchwOp op) {
3495 if (failed(maybeTransformed))
3496 return emitDefaultSilenceableFailure(target);
3498 results.
push_back(maybeTransformed->first);
3500 results.
push_back(maybeTransformed->second);
3515 <<
"only elementwise flattening is supported";
3518 if (target.getNumLoops() <= 1) {
3525 std::iota(reassociation.begin(), reassociation.end(), 0);
3526 auto maybeFlattened =
3528 if (failed(maybeFlattened))
3530 <<
"attempted to flatten, but failed";
3531 results.
push_back(maybeFlattened->collapsedOp);
3532 rewriter.
replaceOp(target, maybeFlattened->results);
3545 auto maybeTransformed =
3547 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3550 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3556 if (failed(maybeTransformed))
3557 return emitDefaultSilenceableFailure(target);
3572 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3573 auto maybeTransformed =
3575 .Case([&](linalg::MatmulOp op) {
3578 .Case([&](linalg::BatchMatmulOp op) {
3581 .Default([&](
Operation *op) {
return failure(); });
3582 if (failed(maybeTransformed))
3592 template <
typename OpTy>
3596 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3597 tensor::ParallelInsertSliceOp>() &&
3600 if (
auto copySource =
3601 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3608 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3610 target->template getParentOfType<scf::InParallelOp>());
3613 Value extracted = rewriter.
create<tensor::ExtractSliceOp>(
3614 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3615 target.getMixedSizes(), target.getMixedStrides());
3616 Value copied = rewriter
3617 .
create<linalg::CopyOp>(target.getLoc(),
3618 target.getSource(), extracted)
3623 target, copied, target.getDest(), target.getMixedOffsets(),
3624 target.getMixedSizes(), target.getMixedStrides());
3626 results.
push_back(copied.getDefiningOp());
3636 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3637 return doit(rewriter, target, results, state);
3638 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3639 return doit(rewriter, target, results, state);
3642 emitSilenceableError()
3643 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3644 diag.attachNote(targetOp->
getLoc()) <<
"target op";
3657 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3659 emitSilenceableError()
3660 <<
"only linalg.copy and tensor.pad target ops are supported";
3661 diag.attachNote(target->
getLoc()) <<
"target op";
3664 assert(target->
getNumResults() == 1 &&
"expected single result");
3666 if (!resultShapedType.hasStaticShape()) {
3668 emitSilenceableError()
3669 <<
"only statically sized ops of rank <= 3 are supported";
3670 diag.attachNote(target->
getLoc()) <<
"target op";
3675 int64_t desiredBitAlignment = getDesiredBitAlignment();
3676 int64_t eltBitwidth =
3677 resultShapedType.getElementType().getIntOrFloatBitWidth();
3678 if (desiredBitAlignment % eltBitwidth != 0) {
3679 desiredBitAlignment = eltBitwidth;
3684 getTotalNumThreads(),
3685 desiredBitAlignment,
3686 resultShapedType.getShape(),
3689 resultShapedType.getElementType().getIntOrFloatBitWidth());
3690 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3692 emitSilenceableError()
3693 <<
"too few threads to map copy op to threads on the most minor "
3694 "dimension, given alignment and vector size constraints, try "
3695 "smaller tile size of mapping to more threads";
3696 diag.attachNote(target->
getLoc()) <<
"target op";
3712 if (!
diag.succeeded())
3729 FailureOr<Operation *> maybeTransformed = failure();
3731 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3736 .Default([&](
Operation *op) {
return false; });
3739 return emitSilenceableError()
3740 <<
"this operation is not supported to convert to Winograd Conv2D";
3743 if (supported && failed(maybeTransformed)) {
3744 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
3751 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3753 #define GET_OP_CLASSES
3754 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static FailureOr< ForallTilingResult > tileToForallOpImpl(RewriterBase &b, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayRef< OpFoldResult >> nominalTileSizes, std::optional< ArrayAttr > mapping, bool omitTileOffsetBoundsCheck)
Rewrite a TilingInterface op to a tiled scf.forall.
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
FailureOr< ForallTilingResult > tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, std::optional< ArrayAttr > mapping)
Same as tileToForallOp, but calculate the number of threads required using the given tileSizes.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite pack as empty + transpose + reshape + extract_slice.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< ForallTilingResult > tileToForallOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayAttr > mapping)
void hoistRedundantVectorTransfers(Operation *root)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< GenericOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace linalg.batch_matmul ops with transposed variants.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operations.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
uint64_t getM(LevelType lt)
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Rewrite a TilingInterface op to a tiled scf.forall, applying tiling by numThreads.
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.