44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
// Debug-output helpers for this file. DBGS() prefixes every message with
// "[linalg-transforms]: " on llvm::dbgs(); DBGSNL() emits a bare newline; and
// LDBG(X) prints X with the prefix and a trailing newline, compiled in only
// for debug builds via LLVM_DEBUG (enabled at runtime with
// -debug-only=linalg-transforms).
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
64 template <
typename PatternTy,
typename... Args>
67 using OpTy =
typename llvm::function_traits<
68 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69 auto op = dyn_cast<OpTy>(operation);
74 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
79 explicit TrivialPatternRewriter(
MLIRContext *context)
82 TrivialPatternRewriter rewriter(operation->
getContext());
83 rewriter.setInsertionPoint(operation);
84 auto result = pattern.returningMatchAndRewrite(op, rewriter);
87 return cast<LinalgOp>(result->getOperation());
98 if (!isa<IntegerAttr>(ofr.get<
Attribute>()))
99 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
100 result.push_back(ofr);
105 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
107 if (params.size() != 1)
108 return transformOp.emitDefiniteFailure()
109 <<
"requires exactly one parameter associated";
110 result.push_back(params[0]);
114 auto payloadOps = state.getPayloadOps(transformValue);
115 if (!llvm::hasSingleElement(payloadOps)) {
117 transformOp.emitSilenceableError()
118 <<
"handle must be mapped to exactly one payload op";
120 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
127 transformOp.emitSilenceableError()
128 <<
"payload op must have exactly 1 index result";
148 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
150 for (
auto param : params) {
151 if (!isa<IntegerAttr>(param))
152 return transformOp.emitDefiniteFailure()
153 <<
"expected the parameter to be associated with an integer "
155 result.push_back(param);
160 for (
Operation *op : state.getPayloadOps(packedHandle)) {
161 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
163 transformOp.emitSilenceableError()
164 <<
"payload op must have exactly 1 index result";
165 diag.attachNote(op->getLoc())
166 <<
"has " << op->getNumResults() <<
" results";
169 result.push_back(op->getResult(0));
183 if (isa<Attribute>(paramOrHandle)) {
185 cast<IntegerAttr>(paramOrHandle.get<
Attribute>()).getInt());
187 }
else if (isa<ParamType>(paramOrHandle.get<
Value>().
getType())) {
189 if (params.size() != 1)
190 return transformOp.emitSilenceableError() <<
"expected a single param";
192 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
197 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
198 return transformOp.emitSilenceableError() <<
"unexpected value handle";
199 auto payload = state.getPayloadOps(handle);
200 if (!llvm::hasSingleElement(payload))
201 return transformOp.emitSilenceableError()
202 <<
"requires param or handle that is mapped to 1 payload op";
204 Operation *paramOrHandlePayloadOp = *payload.begin();
207 return transformOp.emitSilenceableError()
208 <<
"requires param or handle to be result of op with 1 index "
214 return transformOp.emitSilenceableError()
215 <<
"requires param or handle to be the result of a constant like "
218 reified.push_back(attr.getInt());
227 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
232 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
237 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
242 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
248 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
251 options.rankReductionStrategy =
256 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
261 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
266 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
276 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
281 resultTypes.push_back(b.
getType<transform::AnyValueType>());
282 resultTypes.push_back(b.
getType<transform::AnyOpType>());
283 return build(b, result,
289 void transform::BufferizeToAllocationOp::build(
OpBuilder &b,
292 int64_t memorySpace) {
294 resultTypes.push_back(b.
getType<transform::AnyValueType>());
295 resultTypes.push_back(b.
getType<transform::AnyOpType>());
296 return build(b, result,
312 void notifyOperationInserted(
Operation *op,
314 ForwardingListener::notifyOperationInserted(op, previous);
316 if (previous.
isSet())
318 auto inserted = newOps.insert(op);
320 assert(inserted.second &&
"expected newly created op");
323 void notifyOperationErased(
Operation *op)
override {
324 ForwardingListener::notifyOperationErased(op);
338 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
339 NewOpsListener newOpsListener(previousListener);
343 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
346 }
else if (getMemcpyOp() ==
"memref.copy") {
349 }
else if (getMemcpyOp() ==
"linalg.copy") {
353 llvm_unreachable(
"invalid memcpy op");
355 if (getAllocOp() ==
"memref.alloc") {
358 }
else if (getAllocOp() ==
"memref.alloca") {
362 llvm_unreachable(
"invalid alloc op");
364 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
365 options.emitDealloc = getEmitDealloc();
369 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
371 for (
Operation *op : state.getPayloadOps(getTarget())) {
376 <<
"failed to bufferize operation";
377 diag.attachNote(op->
getLoc()) <<
"target payload op";
380 allocatedBuffers.push_back(buffer);
384 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
385 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
389 void transform::BufferizeToAllocationOp::getEffects(
391 if (getBufferizeDestinationOnly()) {
403 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
404 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
405 return emitOpError() <<
"unsupported memcpy op";
406 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
407 return emitOpError() <<
"unsupported alloc op";
420 #define DOWNSCALE(trans) \
422 FailureOr<LinalgOp> res = tryApply<trans>(target); \
423 if (succeeded(res)) { \
424 results.push_back(*res); \
425 return DiagnosedSilenceableFailure::success(); \
// DOWNSCALE_CALL(a, b) names the DownscaleSizeOneWindowed2DConvolution
// pattern specialized on source op `a` and target op `b`;
// DOWNSCALE_NORMAL(a, b) feeds that pattern type into the DOWNSCALE macro,
// which tries to apply it to `target` and returns success if it matched.
429 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
430 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
443 #undef DOWNSCALE_NORMAL
444 #undef DOWNSCALE_CALL
446 return emitDefaultSilenceableFailure(target);
460 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
461 if (!decomposableOp) {
463 "payload is not a decomposable op"));
464 return emitDefaultSilenceableFailure(target);
467 FailureOr<SmallVector<Value>> maybeNewResults =
468 decomposableOp.decomposeOperation(rewriter);
469 if (failed(maybeNewResults))
470 return emitDefaultSilenceableFailure(target);
472 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
473 for (
Value val : *maybeNewResults) {
474 Operation *definition = val.getDefiningOp();
485 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
492 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
496 options.allowReturnAllocsFromLoops =
true;
498 for (
Operation *target : state.getPayloadOps(getTarget())) {
502 <<
"failed to analyze op";
504 rewriter, target, state)))
506 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
517 template <
typename Range>
521 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
527 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
528 if (!tilingInterfaceOp)
529 return transformOp->
emitError(
"only TilingInterface ops are supported");
532 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
533 applyFn(tilingInterfaceOp);
534 if (failed(tiledResults))
539 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
540 for (
Operation *toReplace : opsToReplace) {
541 for (
OpResult res : toReplace->getResults())
542 if (
auto replacement = tiledResults->replacements.lookup(res))
544 if (toReplace->use_empty()) {
550 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
551 assert(tiledResults->loops.size() == numLoops &&
552 "Mismatched number of loops, tile and fuse transform should have "
554 for (
unsigned int i = 0; i < numLoops; ++i)
555 loopOps[i].push_back(tiledResults->loops[i]);
558 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
559 for (
unsigned int i = 0; i < numLoops; ++i)
560 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
570 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
572 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
578 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
582 if (getApplyCleanup()) {
585 tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
591 rewriter, getOperation(), state.getPayloadOps(getTarget()),
592 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
593 [&](TilingInterface tilingInterfaceOp)
594 -> FailureOr<scf::SCFTileAndFuseResult> {
595 return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
604 extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
605 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
606 if (!std::is_permutation(sequence.begin(), sequence.end(),
607 permutation.begin(), permutation.end())) {
608 return emitOpError() <<
"expects interchange to be a permutation, found "
609 << getTileInterchange();
613 extractFromIntegerArrayAttr<int64_t>(getTileSizes());
614 size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
615 if (numExpectedLoops != getNumResults() - 1)
616 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
625 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
628 Value containingOp) {
631 result.
addTypes({resultType, resultType});
647 (domInfo.
dominates(containingOp, user))) {
648 dominatedUsers.insert(user);
651 if (dominatedUsers.empty())
655 auto forallOp = cast<scf::ForallOp>(containingOp);
661 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
666 newOuts.push_back(outputs[resultNumber]);
669 auto newforallOp = rewriter.
create<scf::ForallOp>(
670 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
671 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
673 newforallOp.getRegion().takeBody(forallOp.getRegion());
678 newforallOp.getBody()->addArgument(newOuts.back().getType(),
679 newOuts.back().getLoc());
680 auto bbArgs = newforallOp.getBody()->getArguments();
688 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
690 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
691 Operation *firstYieldOp = yieldingOps.front();
694 Value dst = newforallOp.getRegionIterArgs().back();
696 rewriter.
create<tensor::ParallelInsertSliceOp>(firstYieldOp->
getLoc(), src,
697 dst, offsets, sizes, strides);
701 newforallOp->getResult(result.index()));
704 newforallOp->getResults().back(),
707 return dominatedUsers.contains(user);
718 static std::tuple<SmallVector<Operation *>,
Operation *>
721 LLVM_DEBUG(
DBGS() <<
"Try to fuse a direct extract use\n");
722 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
723 if (!tileableProducer) {
725 <<
"producer is not a TileableInterface: " << *producerOp;
732 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
733 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
734 return sliceOp && containingOp->isProperAncestor(sliceOp);
738 if (it == tileableProducer->getUsers().end()) {
739 diag.attachNote(tileableProducer->getLoc())
740 <<
"could not find fusion opportunity for: " << *tileableProducer;
743 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
750 int64_t resultNumber =
751 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
752 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
757 FailureOr<TilingResult> tileAndFuseResult =
758 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
761 if (failed(tileAndFuseResult)) {
762 diag.attachNote(tileableProducer->getLoc())
763 <<
"failed to tile producer op: " << *tileableProducer;
768 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
769 LLVM_DEBUG(
DBGS() <<
"tiledProducer: " << *tiledOp <<
"\n");
774 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
775 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
776 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
777 if (failed(maybeRankReduced)) {
779 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
780 << tileAndFuseResult->tiledValues[0]
781 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
784 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
788 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
789 resultNumber, offsets, sizes);
791 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
804 LLVM_DEBUG(
DBGS() <<
"Try to fuse an extract use through block argument\n");
806 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
807 if (!tileableProducer) {
809 <<
"producer is not a TileableInterface: " << *producerOp;
814 scf::ForallOp forallOp;
815 auto itProducerUses =
816 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
817 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
821 if (!forallOp || forallOp != containingOp) {
822 diag.attachNote(tileableProducer->getLoc())
823 <<
"could not find a use by the containing op: " << *tileableProducer;
838 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
839 return sliceOp && containingOp->isProperAncestor(sliceOp);
843 if (itBBArgUsers == bbArg.
getUsers().end()) {
845 <<
"could not find fusion opportunity for bbArg: " << bbArg;
848 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
856 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
857 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
862 rewriter, tileableProducer->getLoc(), tileableProducer,
863 destinationTensors))) {
864 diag.attachNote(tileableProducer->getLoc())
865 <<
"failed to get destination tensors for: " << *tileableProducer;
870 bvm.
map(destinationTensors[resultNumber], bbArg);
871 auto tileableProducerClone =
872 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
874 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
877 FailureOr<TilingResult> tileAndFuseResult =
878 tileableProducerClone.generateResultTileValue(
879 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
880 sliceOpToTile.getMixedSizes());
881 if (failed(tileAndFuseResult)) {
882 diag.attachNote(tileableProducer->getLoc())
883 <<
"failed to tile producer op: " << *tileableProducer;
888 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
889 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
890 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
891 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
892 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
897 destinationTensors.front());
900 return tileAndFuseResult->tiledOps;
906 LLVM_DEBUG(
DBGS() <<
"Try to fuse an use by cloning\n");
911 for (
OpOperand &use : result.getUses()) {
913 uses.push_back(&use);
918 if (containingOp == use.getOwner()) {
920 <<
"producer op use by containing op cannot be fused by cloning";
928 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
937 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
938 "Parallel insert slice is not a valid clone destination");
939 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
940 LLVM_DEBUG(
DBGS() <<
"resultNumber: " << resultNumber <<
"\n");
944 fusedOp = rewriter.
clone(*producerOp);
946 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
951 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
961 auto producerOps = state.getPayloadOps(getProducerOp());
962 auto containingOps = state.getPayloadOps(getContainingOp());
963 if (!llvm::hasSingleElement(containingOps)) {
965 <<
"requires exactly one containing_op handle (got "
966 << llvm::range_size(containingOps) <<
")";
968 Operation *containingOp = *containingOps.begin();
971 if (std::empty(producerOps)) {
973 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
981 auto getNextProducer = [&]() -> FailureOr<Operation *> {
982 for (
const auto &it :
enumerate(remainingProducers)) {
985 int64_t numUsesInContainingOp =
987 return containingOp->isAncestor(op);
992 if (numUsesInContainingOp > 0) {
993 if (numUsesInContainingOp == 1)
994 remainingProducers.erase(remainingProducers.begin() + it.index());
1001 while (!remainingProducers.empty()) {
1002 auto nextProducer = getNextProducer();
1003 if (failed(nextProducer)) {
1005 <<
"could not find next producer to fuse into container";
1006 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1014 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1021 auto [tiledOps, newContainingOp] =
1023 if (!tiledOps.empty()) {
1024 LLVM_DEBUG(
DBGS() <<
"\nFused a direct extract use\n" << *containingOp);
1025 fusedOps.append(tiledOps);
1026 if (newContainingOp) {
1034 LogicalResult replacementStatus =
1037 (void)replacementStatus;
1038 assert(succeeded(replacementStatus) &&
1039 "unable to update transform state mapping");
1040 rewriter.
eraseOp(containingOp);
1041 containingOp = newContainingOp;
1048 rewriter,
diag, producerOp, containingOp);
1049 if (!tiledContainingOpOperand.empty()) {
1050 LLVM_DEBUG(
DBGS() <<
"\nFused an extract use through block argument\n"
1052 fusedOps.append(tiledContainingOpOperand);
1059 LLVM_DEBUG(
DBGS() <<
"\nFused an use by cloning\n" << *containingOp);
1060 fusedOps.push_back(cloned);
1066 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1067 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1071 void transform::FuseIntoContainingOp::getEffects(
1089 if (isa<GenericOp>(target)) {
1095 if (succeeded(
generic)) {
1096 results.
push_back(generic->getOperation());
1099 return emitDefaultSilenceableFailure(target);
1112 if (!isa<GenericOp>(target)) {
1117 FailureOr<LinalgOp> named =
1119 if (succeeded(named)) {
1120 results.
push_back(named->getOperation());
1123 return emitDefaultSilenceableFailure(target);
1137 if (interchangeVector.empty()) {
1142 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1143 if (interchangeVector.size() != numLoops) {
1144 return emitSilenceableError()
1145 << getIteratorInterchangeAttrName() <<
" has length ("
1146 << interchangeVector.size()
1147 <<
") different from the number of loops in the target operation ("
1160 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1161 if (!std::is_permutation(sequence.begin(), sequence.end(),
1162 permutation.begin(), permutation.end())) {
1163 return emitOpError()
1164 <<
"expects iterator_interchange to be a permutation, found "
1165 << getIteratorInterchange();
1179 FailureOr<LowerPackResult> res =
lowerPack(rewriter, target);
1182 <<
"cannot lower to pad + expand + transpose";
1185 transformResults.
push_back(res->expandShapeOp);
1186 transformResults.
push_back(res->transposeOp);
1199 FailureOr<LowerUnPackOpResult> res =
lowerUnPack(rewriter, target);
1202 emitSilenceableError()
1203 <<
"cannot lower to transpose + collapse + extract";
1204 diag.attachNote(target->getLoc()) <<
"target payload op";
1207 transformResults.
push_back(res->emptyOp);
1208 transformResults.
push_back(res->transposeOp);
1209 transformResults.
push_back(res->collapseShapeOp);
1210 transformResults.
push_back(res->extractSliceOp);
1240 if (getOps().has_value())
1241 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1242 getOps()->getAsValueRange<StringAttr>().end());
1244 auto payloadOps = state.getPayloadOps(getTarget());
1245 if (!llvm::hasSingleElement(payloadOps)) {
1250 bool incorrectNumOperandTypes =
false;
1257 if (getInterface().has_value()) {
1258 auto iface = getInterface().value();
1259 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1262 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1263 !isa<TilingInterface>(op))
1265 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1266 !isa<LoopLikeOpInterface>(op))
1271 if (getOpAttrs().has_value()) {
1272 DictionaryAttr opAttrs = getOpAttrs().value();
1274 if (attr.getName() == getInterfaceAttrName() ||
1275 attr.getName() == getOpsAttrName())
1277 if (!op->
hasAttr(attr.getName()))
1279 if (op->
getAttr(attr.getName()) != attr.getValue())
1284 if (getFilterResultType().has_value()) {
1285 Type t = getFilterResultType().value();
1290 if (getFilterOperandTypes().has_value()) {
1291 mlir::ArrayAttr types = getFilterOperandTypes().value();
1294 if (types.size() == 1) {
1297 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1298 Type t = cast<::mlir::Type>(typeattr.getValue());
1300 [&](
Type operandType) { return operandType == t; }))
1305 if (types.size() != operandTypes.size()) {
1306 incorrectNumOperandTypes =
true;
1310 for (
auto [attr, operandType] :
1311 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1312 auto typeattr = cast<mlir::TypeAttr>(attr);
1313 Type type = cast<::mlir::Type>(typeattr.getValue());
1315 if (type != operandType)
1326 (*payloadOps.begin())->
walk(matchFun);
1327 if (incorrectNumOperandTypes)
1329 "type, then it must contain as much types as "
1330 "the number of operands in the target ops");
1331 results.
set(cast<OpResult>(getResult()), res);
1346 Type &targetType,
Type &lowSizeType,
1348 Type &splitPointType) {
1349 FunctionType funcType;
1351 if (failed(parser.
parseType<FunctionType>(funcType)))
1354 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1355 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1356 "argument and one result";
1358 targetType = funcType.getInput(0);
1359 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1367 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1368 if (target.hasDynamicShape()) {
1369 auto diag = emitSilenceableError()
1370 <<
"cannot compute parametric tile sizes for dynamically "
1371 "shaped payload op";
1372 diag.attachNote(target->getLoc()) <<
"payload op";
1377 target, getDimension(), getTargetSize(), getDivisor());
1379 return emitSilenceableError()
1380 <<
"failed to compute multi-size tiling sizes";
1383 Builder builder(target.getContext());
1384 results.
assign(llvm::map_range(
1386 spec->lowTileSize * spec->lowTripCount}),
1387 [&builder,
this](int64_t value) {
1399 builder, target, getDimension(), targetSize, divisor);
1401 return emitSilenceableError() <<
"could not generate tile size computation";
1408 {spec->lowTileSize, spec->lowTripCount});
1409 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1410 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1411 assert(lowTileSize && highTileSize && splitPoint &&
1412 "tile sizes are not produced by operations");
1420 void transform::MultiTileSizesOp::getEffects(
1424 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1433 return emitOpError() <<
"expects all results type to be the same";
1453 builder.
getContext(), GenericOp::getOperationName());
1454 build(builder, result,
1463 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1470 auto targetOps = state.getPayloadOps(getTarget());
1472 if (std::empty(targetOps)) {
1473 transformResults.
set(cast<OpResult>(getPackedOp()),
1478 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1479 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1480 return emitSilenceableError()
1481 <<
"requires target to map to exactly 1 LinalgOp (got "
1482 << llvm::range_size(targetOps) <<
")";
1485 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1486 return emitSilenceableError()
1487 <<
"requires number of packed sizes match the number of loops ("
1488 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1495 state, *
this, packedSizes, getMixedPackedSizes());
1498 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1499 if (failed(maybeResult))
1502 transformResults.
set(cast<OpResult>(getPackedOp()),
1503 {maybeResult->packedLinalgOp.getOperation()});
1507 void transform::PackOp::getEffects(
1521 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1522 <<
" is not a valid permutation";
1525 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1526 for (
auto [s, nmo] :
1527 llvm::zip_equal(getMixedMatmulPackedSizes(),
1528 getMatmulPaddedSizesNextMultipleOf())) {
1531 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1532 return emitOpError() <<
"at most one of the packed_size and the "
1533 "padded_sizes_next_multiple_of can be nonzero "
1534 "for the matmul strategy";
1546 for (
Operation *op : state.getPayloadOps(getTarget())) {
1547 auto linalgOp = dyn_cast<LinalgOp>(op);
1558 getMixedMatmulPackedSizes(),
1560 getMatmulPaddedSizesNextMultipleOf(),
1561 getMatmulInnerDimsOrder());
1562 if (succeeded(packResult)) {
1563 results.push_back(packResult->packedLinalgOp);
1566 results.push_back(linalgOp);
1568 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1574 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1578 void transform::PackGreedilyOp::getEffects(
1592 return emitOpError() << getInnerPermAttrName()
1593 <<
" is not a valid permutation";
1596 return emitOpError() << getOuterPermAttrName()
1597 <<
" is not a valid permutation";
1599 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1600 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1601 <<
" or " << getOuterPermAttrName()
1602 <<
" must be specified";
// Selects which dimension permutation of a tensor.pack/tensor.unpack op is
// being validated/applied: the outer dims permutation or the inner
// (packed-tile) dims permutation. Used below by the permutation-validity
// check and by PackTransposeOp.
1608 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1618 template <
typename RelayoutOpTy>
1621 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1623 llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1624 "applies to only pack or unpack operations");
1625 if (!op || permutation.empty())
1627 size_t innerRank = op.getInnerDimsPos().size();
1628 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1632 if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1633 return permutation.size() == op.getSourceRank() &&
1636 return permutation.size() == op.getDestRank() &&
1644 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1645 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1647 if (std::empty(packOrUnpackOps)) {
1648 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1649 transformResults.
set(cast<OpResult>(getPackOp()), {});
1650 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1656 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1657 !llvm::hasSingleElement(linalgOps)) {
1658 return emitSilenceableError()
1659 <<
"requires target to map to exactly 1 "
1660 "packing op and 1 packed op ("
1661 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
1662 << llvm::range_size(linalgOps) <<
")";
1666 auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1667 auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1668 if ((!packOp && !unPackOp)) {
1669 return emitSilenceableError() <<
"requires target to map to a "
1670 "tensor.pack or tensor.unpack";
1672 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1673 if (!linalgOpTarget)
1674 return emitSilenceableError() <<
"requires a LinalgOp target";
1678 if (packOp && packOp.getResult().hasOneUse())
1679 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1681 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1682 if (linalgOp != linalgOpTarget) {
1684 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
1685 : StringLiteral{
"not produced by the LinalgOp target"};
1686 return emitSilenceableError() << errorMsg;
1692 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
1693 OpOperand *packUse = linalgOp.getDpsInitOperand(
1694 cast<OpResult>(unPackOp.getSource()).getResultNumber());
1696 if (!packOp || !packOp.getResult().hasOneUse())
1697 return emitSilenceableError() <<
"could not find matching pack op";
1701 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1703 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1704 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1705 ? StringLiteral{
"invalid outer_perm"}
1706 : StringLiteral{
"invalid inner_perm"};
1710 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1711 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
1717 assert(packOp && linalgOp &&
"unexpected null op");
1721 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1723 assert(succeeded(res) &&
"unexpected packTranspose failure");
1726 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1727 transformResults.
set(cast<OpResult>(getPackedOp()),
1728 {res->transposedLinalgOp});
1730 transformResults.
set(cast<OpResult>(getUnPackOp()),
1731 {res->transposedUnPackOp});
1733 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1748 StringRef copyBackOp) {
1758 (padToMultipleOf.empty()
1760 : b.getDenseI64ArrayAttr(padToMultipleOf)),
1761 b.getI64ArrayAttr(nofoldFlags),
1762 b.getArrayAttr(transposePaddings),
1763 b.getStringAttr(copyBackOp));
1771 StringRef copyBackOp) {
1776 staticPadToMultipleOf);
1783 dynamicPadToMultipleOf,
1784 staticPadToMultipleOf,
1790 void PadOp::getEffects(
1800 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1807 auto transformOp = cast<TransformOpInterface>(getOperation());
1810 for (
Operation *target : state.getPayloadOps(getTarget())) {
1811 auto linalgTarget = dyn_cast<LinalgOp>(target);
1812 if (!linalgTarget) {
1813 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
1814 diag.attachNote(target->
getLoc()) <<
"target op";
1820 for (int64_t packPadding :
1821 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1822 nofoldFlags.push_back(
static_cast<bool>(packPadding));
1826 for (
auto const &it :
1827 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1828 auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1830 emitOpError(
"expects padding values to be typed attributes");
1835 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
1839 if (!parsedAttr || parsedAttr.getType() != elementType) {
1840 auto diag = this->emitOpError(
"expects a padding that parses to ")
1841 << elementType <<
", got " << std::get<0>(it);
1842 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1845 paddingValues.push_back(parsedAttr);
1849 if (attr.getType() != elementType) {
1850 auto diag = this->emitOpError(
"expects a padding value of type ")
1851 << elementType <<
", got " << attr;
1852 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
1855 paddingValues.push_back(attr);
1860 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1861 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1862 cast<ArrayAttr>(transposeVector)));
1867 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1871 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1874 if (padToMultipleOf.empty())
1878 options.padToMultipleOf = padToMultipleOf;
1879 options.paddingValues = paddingValues;
1880 options.nofoldFlags = nofoldFlags;
1881 if (getCopyBackOp() ==
1882 bufferization::MaterializeInDestinationOp::getOperationName()) {
1885 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1887 }
else if (getCopyBackOp() == kCopyOpNone) {
1890 llvm_unreachable(
"unsupported copy_back op");
1896 replacements, newPadOps))) {
1897 auto diag = emitSilenceableError() <<
"failed to pad op";
1898 diag.attachNote(target->
getLoc()) <<
"target op";
1907 rewriter.
replaceOp(linalgTarget, replacements);
1908 paddedOps.push_back(paddedOp);
1909 padOps.append(newPadOps.begin(), newPadOps.end());
1911 for (
Value v : replacements) {
1912 Operation *copyBackOp = v.getDefiningOp();
1913 if (!llvm::is_contained(copyBackOps, copyBackOp))
1914 copyBackOps.push_back(copyBackOp);
1919 results.
set(cast<OpResult>(getPadded()), paddedOps);
1920 results.
set(cast<OpResult>(getPad()), padOps);
1921 results.
set(cast<OpResult>(getCopy()), copyBackOps);
1927 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
1928 if (any_of(nofoldFlags, [](int64_t packPadding) {
1929 return packPadding != 0 && packPadding != 1;
1931 return emitOpError()
1932 <<
"expects nofold_flags to contain booleans (0/1), found "
1933 << getNofoldFlags();
1937 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1938 if (any_of(paddingDimensions,
1939 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
1940 return emitOpError() <<
"expects padding_dimensions to contain positive "
1942 << getPaddingDimensions();
1944 if (!getMixedPadToMultipleOf().empty()) {
1945 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
1946 return emitOpError() <<
"expects as many multiples as padding_dimensions";
1949 ArrayAttr transposes = getTransposePaddings();
1952 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
1953 if (!std::is_permutation(sequence.begin(), sequence.end(),
1955 return emitOpError()
1956 <<
"expects transpose_paddings to be a permutation, found "
1960 if (getCopyBackOp() !=
1961 bufferization::MaterializeInDestinationOp::getOperationName() &&
1962 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1963 getCopyBackOp() != kCopyOpNone)
1964 return emitOpError() <<
"invalid copy_back_op";
1976 auto targetOps = state.getPayloadOps(getTarget());
1977 auto loopOps = state.getPayloadOps(getLoop());
1978 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1980 <<
"requires exactly one target and one loop handle (got "
1981 << llvm::range_size(targetOps) <<
" and "
1982 << llvm::range_size(loopOps) <<
")";
1985 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1986 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1987 if (!padOp || !loopOp)
1990 FailureOr<linalg::detail::PackingResult> result =
1996 if (result->clonedLoopIvs.empty()) {
1997 transformResults.
set(cast<OpResult>(getPackingLoop()),
1998 {result->hoistedPadOp.getOperation()});
2001 auto outerPackedLoop =
2003 transformResults.
set(cast<OpResult>(getPackingLoop()),
2004 {outerPackedLoop.getOperation()});
2010 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2011 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2013 return emitOpError() <<
"expects transpose to be a permutation, found "
2019 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2029 tensor::PadOp target,
2032 tensor::PadOp hoistedPadOp;
2034 FailureOr<Value> result =
2036 hoistedPadOp, transposeOps);
2037 if (succeeded(result)) {
2047 return emitDefaultSilenceableFailure(target);
2052 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0,
transpose.size()));
2053 if (!std::is_permutation(sequence.begin(), sequence.end(),
transpose.begin(),
2055 return emitOpError() <<
"expects transpose to be a permutation, found "
2071 if (!getOperandsToPromote().empty())
2073 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2074 if (getUseFullTilesByDefault())
2076 getUseFullTilesByDefault());
2078 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2079 if (!getUseFullTileBuffers().empty())
2081 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2082 if (getAlignment().has_value())
2083 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2084 if (getMemorySpace().has_value())
2085 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2087 if (getMapping().has_value()) {
2089 auto mapping = *getMapping();
2090 if (mapping.size() > 1)
2091 return emitDefaultDefiniteFailure(target);
2093 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2095 if (addressSpace.getAddressSpace() ==
2096 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2103 }
else if (addressSpace.getAddressSpace() ==
2104 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2112 return emitDefaultDefiniteFailure(target);
2117 return emitDefaultDefiniteFailure(target);
2120 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2122 return emitDefaultDefiniteFailure(target);
2135 auto payload = state.getPayloadOps(getTarget());
2139 if (target->getNumOperands() > 0)
2142 target->getNumRegions() > 0)
2144 <<
"expected target that is isolated from above";
2148 Operation *pattern = &getBodyRegion().front().front();
2151 if (getOperation()->isAncestor(target))
2156 replacements.push_back(replacement);
2158 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2162 void transform::ReplaceOp::getEffects(
2170 if (!getBodyRegion().hasOneBlock())
2171 return emitOpError() <<
"expected one block";
2172 if (std::distance(getBodyRegion().front().begin(),
2173 getBodyRegion().front().end()) != 1)
2174 return emitOpError() <<
"expected one operation in block";
2175 Operation *replacement = &getBodyRegion().front().front();
2178 <<
"expected replacement without operands";
2182 <<
"expect op that is isolated from above";
2200 target.createFlatListOfOperandDims(b, loc);
2201 AffineMap map = target.getShapesToLoopsMap();
2217 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2218 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2219 if (failed(maybeTilingResult))
2220 return emitDefaultDefiniteFailure(target);
2222 if (target->getNumResults())
2223 rewriter.
replaceOp(target, maybeTilingResult->replacements);
2227 results.
reserve(maybeTilingResult->tiledOps.size());
2228 for (
Operation *tiled : maybeTilingResult->tiledOps)
2242 for (
Operation *target : state.getPayloadOps(getTarget())) {
2243 auto tilingOp = dyn_cast<TilingInterface>(*target);
2246 emitSilenceableError()
2247 <<
"expected the payload to implement TilingInterface";
2248 diag.attachNote(target->getLoc()) <<
"payload op";
2252 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2254 if (failed(generatedLoops))
2255 return emitDefaultDefiniteFailure(target);
2256 for (scf::ForOp &loop : *generatedLoops) {
2257 loops.push_back(loop.getOperation());
2261 results.
set(cast<OpResult>(getResult()), loops);
2270 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2276 FailureOr<Operation *> maybeResult =
2278 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2279 [&rewriter](
auto op) {
2282 if (failed(maybeResult))
2283 return emitDefaultSilenceableFailure(target);
2297 llvm::to_vector(state.getPayloadOps(getTarget()));
2299 bool isMultiwaySplit = getMultiway();
2301 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2303 <<
"requires exactly one target when "
2304 "multiway split is enabled (got "
2305 << llvm::range_size(payload) <<
")";
2310 if (!isMultiwaySplit)
2311 chunkSizes.reserve(payload.size());
2313 if (getDynamicChunkSizes()) {
2315 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2316 chunkSizes = llvm::to_vector(llvm::map_range(
2317 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2320 diag = emitSilenceableError()
2321 <<
"expected dynamic split point handle to point to a "
2322 "single-result index-typed op";
2323 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2328 chunkSizes = llvm::to_vector(
2329 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2332 if (
diag.isSilenceableFailure())
2337 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2339 <<
"expected the dynamic split point handle to point to as "
2341 << chunkSizes.size() <<
") as the target handle ("
2342 << payload.size() <<
")";
2345 chunkSizes.resize(payload.size(),
2349 auto checkStructuredOpAndDimensions =
2352 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2353 diag.attachNote(loc) <<
"target op";
2357 if (getDimension() >= linalgOp.getNumLoops()) {
2358 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2359 <<
" does not exist in target op";
2360 diag.attachNote(loc) <<
"target op";
2366 auto checkFailureInSplitting =
2370 diag.attachNote(loc) <<
"target op";
2377 if (isMultiwaySplit) {
2380 TilingInterface head, tail;
2383 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2387 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2388 if (
diag.isSilenceableFailure())
2394 target = tail.getOperation();
2399 linalgOp = cast<LinalgOp>(target);
2404 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2405 getDimension(), chunkSize);
2409 checkFailureInSplitting(!head && !tail, loc);
2410 if (
diag.isDefiniteFailure())
2413 opList.push_back(head.getOperation());
2418 opList.push_back(tail.getOperation());
2424 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2427 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2429 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2431 if (
diag.isSilenceableFailure())
2435 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2436 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2437 getDimension(), std::get<1>(pair));
2441 checkFailureInSplitting(!first.back() && !second.back(), loc);
2446 if (!second.back()) {
2447 noSecondPart = target;
2452 if (second.size() != first.size() && !second.empty()) {
2453 auto diag = emitSilenceableError()
2454 <<
"splitting does not produce the second part for a subset "
2457 <<
"expected splitting to produce the second part of all "
2458 "or none of the targets";
2460 <<
"first target with no second part";
2464 opList.append(first);
2466 opList.append(second);
2468 results.
set(cast<OpResult>(getSplitList()), opList);
2472 void SplitOp::getEffects(
2475 if (getDynamicChunkSizes())
2483 IntegerAttr staticChunkSizes;
2489 if (!dynamicPointParseResult.
has_value()) {
2490 int64_t staticChunkSizesValue;
2491 if (failed(parser.
parseInteger(staticChunkSizesValue)))
2504 if (dynamicPointParseResult.
has_value()) {
2505 Type ChunkSizesType;
2506 if (failed(*dynamicPointParseResult) || parser.
parseComma() ||
2518 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
2525 printer <<
" " << getTarget() <<
" after ";
2526 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
2527 if (staticChunkSize != ShapedType::kDynamic)
2528 printer << staticChunkSize;
2530 printer << getDynamicChunkSizes();
2533 {getStaticChunkSizesAttrName()});
2534 printer <<
" : " << getTarget().getType();
2535 if (staticChunkSize == ShapedType::kDynamic)
2536 printer <<
", " << getDynamicChunkSizes().getType();
2540 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2541 (getDynamicChunkSizes() ==
nullptr)) {
2542 return emitOpError() <<
"expects either a dynamic or a static split "
2543 "point to be provided";
2552 void transform::SplitReductionOp::build(
2554 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
2555 bool useScalingAlgorithm,
bool useAlloc) {
2561 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
2563 if (innerParallel) {
2564 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
2567 if (useScalingAlgorithm) {
2569 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
2577 result.
addTypes({resultType, resultType, resultType, resultType});
2586 unsigned(getInsertSplitDimension()),
2587 bool(getInnerParallel())};
2590 FailureOr<SplitReductionResult> splitResult =
2591 (getUseScalingAlgorithm())
2594 if (failed(splitResult))
2595 return emitDefaultDefiniteFailure(target);
2597 results.
push_back(splitResult->initOrAlloc);
2599 results.
push_back(splitResult->splitLinalgOp);
2600 results.
push_back(splitResult->resultCombiningLinalgOp);
2608 void transform::TileReductionUsingForOp::build(
2619 build(builder, result,
2622 staticTileSizesAttr);
2631 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2635 return emitDefaultSilenceableFailure(target);
2636 for (
Value initValue : result->initialValues)
2638 for (
auto parallelTiledOp : result->parallelTiledOps)
2640 for (
auto mergeOp : result->mergeOps)
2642 results.
push_back(result->loops.front());
2650 void transform::TileReductionUsingForallOp::build(
2653 ArrayAttr mapping) {
2663 build(builder, result,
2666 staticNumThreadsAttr,
2667 staticTileSizesAttr,
2680 FailureOr<linalg::ForallReductionTilingResult> result =
2682 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2683 numThreads, tileSizes, getMapping());
2685 if (failed(result)) {
2686 auto diag = emitSilenceableError() <<
"could not tile reduction";
2687 diag.attachNote(target.getLoc()) <<
"target operation";
2690 for (
Value initValue : result->initialValues)
2692 for (
auto parallelTiledOp : result->parallelTiledOps)
2694 for (
auto mergeOp : result->mergeOps)
2710 llvm::to_vector(state.getPayloadOps(getTarget()));
2712 if (!llvm::hasSingleElement(targetOps)) {
2714 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
2719 auto linalgOp = dyn_cast<LinalgOp>(target);
2720 auto tileableOp = dyn_cast<TilingInterface>(target);
2725 OpBuilder builder(linalgOp.getContext());
2727 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
2728 if (linalgOp.hasDynamicShape()) {
2729 auto diag = emitSilenceableError()
2730 <<
"cannot compute parametric tile sizes for dynamically "
2731 "shaped payload op";
2732 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
2736 FailureOr<StaticContinuousTileSizeSpecification> spec =
2740 return emitSilenceableError()
2741 <<
"failed to compute multi-size tiling sizes";
2746 for (
auto &&[tileSize, tripCount] :
2747 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2748 chunkSizes.push_back(tileSize * tripCount);
2751 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
2755 transformResults.
setParams(cast<OpResult>(getTileSizes()),
2756 getI64AttrsFromI64(spec->tileSizes));
2757 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
2758 getI64AttrsFromI64(chunkSizes));
2766 unsigned dimension = getDimension();
2769 builder, tileableOp, dimension, targetSize,
true);
2771 return emitSilenceableError() <<
"could not generate tile size computation";
2783 for (
auto &&[tileSize, tripCount] :
2784 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2785 splitPoint = apply(s0 * s1, {tileSize, tripCount});
2786 chunkSizes.push_back(splitPoint);
2790 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
2795 transformResults.
set(cast<OpResult>(getTileSizes()),
2796 getDefiningOps(spec->tileSizes));
2797 transformResults.
set(cast<OpResult>(getChunkSizes()),
2798 getDefiningOps(chunkSizes));
2806 return emitOpError() <<
"expects all results type to be the same";
2812 void transform::ContinuousTileSizesOp::getEffects(
2814 if (isa<TransformParamTypeInterface>(getTileSizes().
getType()))
2830 Type &tileSizesType,
2831 Type &chunkSizesType) {
2832 FunctionType funcType;
2834 if (failed(parser.
parseType<FunctionType>(funcType)))
2837 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2838 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
2839 "argument and one result";
2841 targetType = funcType.getInput(0);
2842 tileSizesType = chunkSizesType = funcType.getResult(0);
2851 void transform::TileUsingForOp::build(
2856 return build(builder, result, loopTypes,
2860 interchange, scalableSizes);
2863 void transform::TileUsingForOp::build(
2867 build(builder, result, target,
2869 interchange, scalableSizes);
2872 void transform::TileUsingForOp::build(
2879 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2883 void transform::TileUsingForOp::build(
2895 unsigned numExpectedLoops =
2896 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2898 resultTypes.reserve(numExpectedLoops);
2899 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2900 "expected one loop type or as many as loops");
2901 if (loopTypes.size() == 1)
2902 resultTypes.append(numExpectedLoops, loopTypes[0]);
2904 llvm::append_range(resultTypes, loopTypes);
2906 if (scalableSizes.has_value())
2907 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2908 build(builder, result, target.
getType(),
2912 staticTileSizesAttr,
2914 expandedScalableSizes);
2919 return emitOpError(
"expected same number of sizes (")
2921 << getScalableSizes().size() <<
")";
2923 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
2924 if (getLoops().size() != numExpectedLoops)
2925 return emitOpError(
"expected number of loops to tile (")
2926 << numExpectedLoops <<
") to match number of `loops` results ("
2927 << getLoops().size() <<
")";
2938 llvm::to_vector(state.getPayloadOps(getTarget()));
2944 if (isa<ParamType>(transformValue.getType())) {
2945 dynamicSizeProducers.push_back({});
2947 paramSizes.push_back(
2948 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
2949 return cast<IntegerAttr>(attr).getValue().getSExtValue();
2952 if (paramSizes.back().size() != targets.size()) {
2954 emitSilenceableError()
2955 <<
"expected as many parameter values ("
2956 << dynamicSizeProducers.back().size() <<
") as target ops ("
2957 << targets.size() <<
")";
2958 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
2964 paramSizes.push_back({});
2965 dynamicSizeProducers.push_back(
2966 llvm::to_vector(state.getPayloadOps(transformValue)));
2968 if (dynamicSizeProducers.back().size() != targets.size()) {
2970 emitSilenceableError()
2971 <<
"expected as many dynamic size-producing operations ("
2972 << dynamicSizeProducers.back().size() <<
") as target ops ("
2973 << targets.size() <<
")";
2974 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
2978 for (
Operation *op : dynamicSizeProducers.back()) {
2985 emitSilenceableError() <<
"expected sizes to be produced by ops "
2986 "with a single index-type result";
2987 diag.attachNote(op->
getLoc()) <<
"size producer op";
2988 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
2995 loops.resize(getLoops().size());
2996 auto scalableSizes = getScalableSizes();
2998 auto tilingInterface = dyn_cast<TilingInterface>(op);
2999 if (!tilingInterface) {
3001 emitSilenceableError()
3002 <<
"only ops implementing TilingInterface are supported";
3003 diag.attachNote(op->
getLoc()) <<
"target op";
3006 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3008 emitSilenceableError()
3009 <<
"too many tiles provided, expected at most "
3010 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3011 << tileSizes.size();
3012 diag.attachNote(op->
getLoc()) <<
"target op";
3017 if (tileSizes.empty()) {
3026 sizes.reserve(tileSizes.size());
3027 unsigned dynamicIdx = 0;
3030 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3031 if (scalableSizes[ofrIdx]) {
3032 auto val = b.
create<arith::ConstantIndexOp>(
3033 getLoc(), cast<IntegerAttr>(attr).getInt());
3039 sizes.push_back(attr);
3046 assert((dynamicSizes.empty() ^ params.empty()) &&
3047 "expected either dynamic sizes or parameters");
3048 if (!params.empty()) {
3051 sizes.push_back(dynamicSizes[index]->getResult(0));
3059 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3060 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3061 if (failed(maybeTilingResult))
3064 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3066 tiled.append(maybeTilingResult->tiledOps);
3068 loops[en2.index()].push_back(en2.value());
3071 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3073 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3082 results.reserve(tileSizes.size());
3083 unsigned dynamicPos = 0;
3085 for (int64_t size : tileSizes) {
3086 if (size == ShapedType::kDynamic) {
3087 results.push_back(dynamic[dynamicPos++]);
3095 void transform::TileUsingForOp::getEffects(
3107 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3111 ArrayAttr mapping) {
3112 return build(builder, result,
3120 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3124 ArrayAttr mapping) {
3134 build(builder, result,
3135 TypeRange{operationType, operationType},
3142 staticTileSizesAttr,
3146 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3150 ArrayAttr mapping) {
3151 return build(builder, result, target,
3156 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3160 ArrayAttr mapping) {
3171 build(builder, result,
3172 TypeRange{operationType, operationType},
3178 staticNumThreadsAttr,
3191 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3193 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3195 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3196 normalizedUbs.push_back(normalizedUb);
3198 return normalizedUbs;
3214 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3217 denormalizedIvs.push_back(
3220 return denormalizedIvs;
3231 scf::ForallOp loop) {
3251 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
3252 loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3255 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3257 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3262 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3263 normalizedForallOp.getRegionIterArgs().end());
3264 Block *origLoopBlock = loop.getBody();
3265 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3267 rewriter.
replaceOp(loop, normalizedForallOp);
3268 return normalizedForallOp;
3273 TransformOpInterface transformOp,
Operation *target,
3278 auto tileableOp = dyn_cast<TilingInterface>(target);
3281 transformOp.emitSilenceableError()
3282 <<
"only TilingInterface ops are supported";
3283 diag.attachNote(target->
getLoc()) <<
"target op";
3289 if (!mixedNumThreads.empty()) {
3290 options.setNumThreads(mixedNumThreads);
3292 options.setTileSizes(mixedTileSizes);
3295 options.setMapping(mapping.value().getValue());
3297 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3300 if (failed(maybeTilingResult))
3301 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3303 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3305 tilingResult = *maybeTilingResult;
3307 if (mixedNumThreads.empty()) {
3308 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3311 scf::ForallOp normalizedForallOp =
3313 tilingResult.
loops.front() = normalizedForallOp;
3323 auto transformOp = cast<TransformOpInterface>(getOperation());
3332 getPackedNumThreads()
3334 state, transformOp, mixedNumThreads, getPackedNumThreads())
3336 state, transformOp, mixedNumThreads, getMixedNumThreads());
3340 status = getPackedTileSizes()
3342 state, transformOp, mixedTileSizes, getPackedTileSizes())
3344 state, transformOp, mixedTileSizes, getMixedTileSizes());
3348 for (
Operation *target : state.getPayloadOps(getTarget())) {
3351 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3352 getMapping(), tilingResult);
3353 if (!
diag.succeeded())
3355 tileOps.push_back(tilingResult.
loops.front());
3356 tiledOps.append(tilingResult.
tiledOps);
3359 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3360 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3365 void transform::TileUsingForallOp::getEffects(
3378 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3387 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3388 static_cast<int>(getPackedNumThreads() !=
Value());
3389 if (numThreadsSpec > 1)
3391 "num_threads and packed_num_threads are mutually exclusive");
3392 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3393 static_cast<int>(getPackedTileSizes() !=
Value());
3394 if (tileSizesSpec > 1)
3396 "tile_sizes and packed_tile_sizes are mutually exclusive");
3397 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3398 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3399 "must be specified");
3407 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3409 bool vectorizePadding,
bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3411 if (vectorizePadding) {
3413 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3417 if (vectorizeExtract) {
3419 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3423 if (flatten1DDepthwiseConv) {
3425 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3436 explicit VectorizationPattern(
MLIRContext *context,
3437 bool vectorizeExtract =
false,
3438 bool flattenConv =
false)
3440 vectorizeNDExtract(vectorizeExtract),
3441 flatten1DDepthwiseConv(flattenConv) {}
3442 LogicalResult matchAndRewrite(
Operation *op,
3446 "Unsupported Op, cannot vectorize");
3448 {}, vectorizeNDExtract,
3449 flatten1DDepthwiseConv);
3455 bool vectorizeNDExtract =
false;
3459 bool flatten1DDepthwiseConv =
false;
3464 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3469 auto diag = this->emitOpError(
"requires isolated-from-above targets");
3470 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
3476 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3477 getFlatten_1dDepthwiseConv());
3479 if (!getDisableTransferPermutationMapLoweringPatterns())
3482 if (!getDisableMultiReductionToContractPatterns())
3490 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3491 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3499 if (getVectorizePadding()) {
3511 return emitDefaultDefiniteFailure(target);
3525 auto targets = state.getPayloadOps(getTarget());
3526 if (std::empty(targets))
3528 auto transformOp = cast<TransformOpInterface>(getOperation());
3531 state, transformOp, getMixedVectorSizes(), vectorSizes);
3539 <<
"Unsupported Op, cannot vectorize";
3544 getVectorizeNdExtract().value_or(
false)))) {
3546 <<
"Attempted to vectorize, but failed";
3553 void transform::VectorizeOp::getEffects(
3562 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3566 if (getStaticVectorSizes().size() != getScalableSizes().size())
3567 return emitOpError(
"expected same number of vector sizes (")
3568 << getStaticVectorSizes().size() <<
") and scalable sizes ("
3569 << getScalableSizes().size() <<
")";
3578 transform::HoistRedundantVectorTransfersOp::applyToOne(
3595 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3614 auto maybeTransformed =
3617 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3620 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3623 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3626 .Case([&](linalg::Conv2DNchwFchwOp op) {
3632 if (failed(maybeTransformed))
3633 return emitDefaultSilenceableFailure(target);
3635 results.
push_back(maybeTransformed->first);
3637 results.
push_back(maybeTransformed->second);
3652 <<
"only elementwise flattening is supported";
3655 if (target.getNumLoops() <= 1) {
3662 std::iota(reassociation.begin(), reassociation.end(), 0);
3663 auto maybeFlattened =
3665 if (failed(maybeFlattened))
3667 <<
"attempted to flatten, but failed";
3668 results.
push_back(maybeFlattened->collapsedOp);
3669 rewriter.
replaceOp(target, maybeFlattened->results);
3682 auto maybeTransformed =
3684 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3687 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3693 if (failed(maybeTransformed))
3694 return emitDefaultSilenceableFailure(target);
3709 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3710 auto maybeTransformed =
3712 .Case([&](linalg::MatmulOp op) {
3715 .Case([&](linalg::BatchMatmulOp op) {
3718 .Default([&](
Operation *op) {
return failure(); });
3719 if (failed(maybeTransformed))
3729 template <
typename OpTy>
3733 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3734 tensor::ParallelInsertSliceOp>() &&
3737 if (
auto copySource =
3738 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3745 if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3747 target->template getParentOfType<scf::InParallelOp>());
3750 Value extracted = rewriter.
create<tensor::ExtractSliceOp>(
3751 target.getLoc(), target.getDest(), target.getMixedOffsets(),
3752 target.getMixedSizes(), target.getMixedStrides());
3753 Value copied = rewriter
3754 .
create<linalg::CopyOp>(target.getLoc(),
3755 target.getSource(), extracted)
3760 target, copied, target.getDest(), target.getMixedOffsets(),
3761 target.getMixedSizes(), target.getMixedStrides());
3763 results.
push_back(copied.getDefiningOp());
3773 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3774 return doit(rewriter, target, results, state);
3775 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3776 return doit(rewriter, target, results, state);
3779 emitSilenceableError()
3780 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3781 diag.attachNote(targetOp->
getLoc()) <<
"target op";
3794 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3796 emitSilenceableError()
3797 <<
"only linalg.copy and tensor.pad target ops are supported";
3798 diag.attachNote(target->
getLoc()) <<
"target op";
3801 assert(target->
getNumResults() == 1 &&
"expected single result");
3803 if (!resultShapedType.hasStaticShape()) {
3805 emitSilenceableError()
3806 <<
"only statically sized ops of rank <= 3 are supported";
3807 diag.attachNote(target->
getLoc()) <<
"target op";
3812 int64_t desiredBitAlignment = getDesiredBitAlignment();
3813 int64_t eltBitwidth =
3814 resultShapedType.getElementType().getIntOrFloatBitWidth();
3815 if (desiredBitAlignment % eltBitwidth != 0) {
3816 desiredBitAlignment = eltBitwidth;
3821 getTotalNumThreads(),
3822 desiredBitAlignment,
3823 resultShapedType.getShape(),
3826 resultShapedType.getElementType().getIntOrFloatBitWidth());
3827 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3829 emitSilenceableError()
3830 <<
"too few threads to map copy op to threads on the most minor "
3831 "dimension, given alignment and vector size constraints, try "
3832 "smaller tile size of mapping to more threads";
3833 diag.attachNote(target->
getLoc()) <<
"target op";
3849 if (!
diag.succeeded())
3853 for (
auto op : tilingResult.
tiledOps)
3867 FailureOr<Operation *> maybeTransformed = failure();
3869 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3874 .Default([&](
Operation *op) {
return false; });
3877 return emitSilenceableError()
3878 <<
"this operation is not supported to convert to Winograd Conv2D";
3881 if (supported && failed(maybeTransformed)) {
3882 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
3894 FailureOr<Operation *> maybeTransformed = failure();
3897 .Case([&](linalg::WinogradFilterTransformOp op) {
3901 .Case([&](linalg::WinogradInputTransformOp op) {
3905 .Case([&](linalg::WinogradOutputTransformOp op) {
3909 .Default([&](
Operation *op) {
return false; });
3913 emitSilenceableError()
3914 <<
"this operation is not supported to decompose into other operations";
3915 diag.attachNote(target->
getLoc()) <<
"target op";
3919 if (supported && failed(maybeTransformed)) {
3921 emitSilenceableError() <<
"decompose Winograd operations failed";
3922 diag.attachNote(target->
getLoc()) <<
"target op";
3930 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3932 #define GET_OP_CLASSES
3933 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite pack as empty + transpose + reshape + extract_slice.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorisation patterns for tensor.insert_slice.
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
uint64_t getM(LevelType lt)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.