43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/SmallPtrSet.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/DebugLog.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include <type_traits>
55 #define DEBUG_TYPE "linalg-transforms"
62 template <
typename PatternTy,
typename... Args>
65 using OpTy =
typename llvm::function_traits<
66 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
67 auto op = dyn_cast<OpTy>(operation);
72 PatternTy pattern(operation->
getContext(), std::forward<Args>(args)...);
77 auto result = pattern.returningMatchAndRewrite(op, rewriter);
80 return cast<LinalgOp>(result->getOperation());
90 if (
auto attr = dyn_cast<Attribute>(ofr)) {
91 if (!isa<IntegerAttr>(attr))
92 return transformOp.emitDefiniteFailure() <<
"expected IntegerAttr";
93 result.push_back(ofr);
97 Value transformValue = cast<Value>(ofr);
98 if (isa<TransformParamTypeInterface>(transformValue.
getType())) {
100 if (params.size() != 1)
101 return transformOp.emitDefiniteFailure()
102 <<
"requires exactly one parameter associated";
103 result.push_back(params[0]);
107 auto payloadOps = state.getPayloadOps(transformValue);
108 if (!llvm::hasSingleElement(payloadOps)) {
110 transformOp.emitSilenceableError()
111 <<
"handle must be mapped to exactly one payload op";
113 <<
"mapped to " << llvm::range_size(payloadOps) <<
" payload ops";
120 transformOp.emitSilenceableError()
121 <<
"payload op must have exactly 1 index result";
141 if (isa<TransformParamTypeInterface>(packedHandle.
getType())) {
143 for (
auto param : params) {
144 if (!isa<IntegerAttr>(param))
145 return transformOp.emitDefiniteFailure()
146 <<
"expected the parameter to be associated with an integer "
148 result.push_back(param);
153 for (
Operation *op : state.getPayloadOps(packedHandle)) {
154 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
156 transformOp.emitSilenceableError()
157 <<
"payload op must have exactly 1 index result";
158 diag.attachNote(op->getLoc())
159 <<
"has " << op->getNumResults() <<
" results";
162 result.push_back(op->getResult(0));
176 if (
auto attr = dyn_cast<Attribute>(paramOrHandle)) {
177 reified.push_back(cast<IntegerAttr>(attr).getInt());
179 }
else if (isa<ParamType>(cast<Value>(paramOrHandle).
getType())) {
181 if (params.size() != 1)
182 return transformOp.emitSilenceableError() <<
"expected a single param";
184 cast<IntegerAttr>(params.front()).getValue().getSExtValue());
188 Value handle = cast<Value>(paramOrHandle);
189 if (!isa<TransformHandleTypeInterface>(handle.
getType()))
190 return transformOp.emitSilenceableError() <<
"unexpected value handle";
191 auto payload = state.getPayloadOps(handle);
192 if (!llvm::hasSingleElement(payload))
193 return transformOp.emitSilenceableError()
194 <<
"requires param or handle that is mapped to 1 payload op";
196 Operation *paramOrHandlePayloadOp = *payload.begin();
199 return transformOp.emitSilenceableError()
200 <<
"requires param or handle to be result of op with 1 index "
206 return transformOp.emitSilenceableError()
207 <<
"requires param or handle to be the result of a constant like "
210 reified.push_back(attr.getInt());
219 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
224 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
229 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
234 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
240 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
243 options.rankReductionStrategy =
248 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
253 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
258 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
263 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
268 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
287 void notifyOperationInserted(
Operation *op,
289 ForwardingListener::notifyOperationInserted(op, previous);
291 if (previous.
isSet())
293 auto inserted = newOps.insert(op);
295 assert(inserted.second &&
"expected newly created op");
298 void notifyOperationErased(
Operation *op)
override {
299 ForwardingListener::notifyOperationErased(op);
313 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
314 NewOpsListener newOpsListener(previousListener);
318 if (getMemcpyOp() ==
"bufferization.materialize_in_destination") {
321 }
else if (getMemcpyOp() ==
"memref.copy") {
324 }
else if (getMemcpyOp() ==
"linalg.copy") {
328 llvm_unreachable(
"invalid memcpy op");
330 if (getAllocOp() ==
"memref.alloc") {
333 }
else if (getAllocOp() ==
"memref.alloca") {
337 llvm_unreachable(
"invalid alloc op");
339 options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
340 options.emitDealloc = getEmitDealloc();
344 getMemorySpace().has_value() ? getMemorySpace().value() :
Attribute();
346 for (
Operation *op : state.getPayloadOps(getTarget())) {
351 <<
"failed to bufferize operation";
352 diag.attachNote(op->
getLoc()) <<
"target payload op";
355 allocatedBuffers.push_back(buffer);
359 results.
setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
360 results.
set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
364 void transform::BufferizeToAllocationOp::getEffects(
366 if (getBufferizeDestinationOnly()) {
378 if (getMemcpyOp() !=
"bufferization.materialize_in_destination" &&
379 getMemcpyOp() !=
"memref.copy" && getMemcpyOp() !=
"linalg.copy")
380 return emitOpError() <<
"unsupported memcpy op";
381 if (getAllocOp() !=
"memref.alloc" && getAllocOp() !=
"memref.alloca")
382 return emitOpError() <<
"unsupported alloc op";
394 auto linalgOp = dyn_cast<linalg::LinalgOp>(operand.
getOwner());
401 Value blockArgument = linalgOp.getMatchingBlockArgument(&operand);
409 if (!isa<TensorType, FloatType, IntegerType>(value.
getType()))
411 return llvm::any_of(value.
getUses(),
420 for (
Value tensor : state.getPayloadValues(getTensor())) {
421 auto type = dyn_cast<RankedTensorType>(tensor.getType());
423 return emitSilenceableError() <<
"non-tensor type: " << tensor;
426 Operation *definingOp = tensor.getDefiningOp();
433 bool needsMaterialization =
mayBeRead(tensor);
438 if (!ShapedType::isDynamic(dim))
440 Value cst = rewriter.
create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
441 auto dimOp = rewriter.
create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
442 preservedOps.insert(dimOp);
443 dynamicDims.push_back(dimOp);
445 auto allocation = rewriter.
create<bufferization::AllocTensorOp>(
446 tensor.getLoc(), type, dynamicDims);
448 if (getMemorySpaceAttr())
449 allocation.setMemorySpaceAttr(getMemorySpaceAttr());
450 Value allocated = allocation;
454 if (needsMaterialization) {
455 auto copy = rewriter.
create<bufferization::MaterializeInDestinationOp>(
456 tensor.getLoc(), tensor, allocated);
457 preservedOps.insert(
copy);
458 promoted.push_back(
copy.getResult());
460 promoted.push_back(allocated);
464 results.
setValues(cast<OpResult>(getPromoted()), promoted);
468 void transform::PromoteTensorOp::getEffects(
484 #define DOWNSCALE(trans) \
486 FailureOr<LinalgOp> res = tryApply<trans>(target); \
487 if (succeeded(res)) { \
488 results.push_back(*res); \
489 return DiagnosedSilenceableFailure::success(); \
493 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
494 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
507 #undef DOWNSCALE_NORMAL
508 #undef DOWNSCALE_CALL
510 return emitDefaultSilenceableFailure(target);
524 auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
525 if (!decomposableOp) {
527 "payload is not a decomposable op"));
528 return emitDefaultSilenceableFailure(target);
531 FailureOr<SmallVector<Value>> maybeNewResults =
532 decomposableOp.decomposeOperation(rewriter);
533 if (
failed(maybeNewResults))
534 return emitDefaultSilenceableFailure(target);
536 rewriter.
replaceOp(decomposableOp, *maybeNewResults);
537 for (
Value val : *maybeNewResults) {
538 Operation *definition = val.getDefiningOp();
549 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
556 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
560 options.allowReturnAllocsFromLoops =
true;
562 for (
Operation *target : state.getPayloadOps(getTarget())) {
566 <<
"failed to analyze op";
568 rewriter, target, state)))
570 <<
"failed to eliminate LinalgOp anchored tensor.empty ops";
583 bool applyCleanup,
bool useForall) {
585 builder, result, loopTypes,
591 applyCleanup, useForall);
597 bool applyCleanup,
bool useForall) {
605 applyCleanup, useForall);
612 bool applyCleanup,
bool useForall) {
616 build(builder, result, loopTypes, target, mixedTileSizes,
617 mixedTileInterchange, applyCleanup, useForall);
624 bool applyCleanup,
bool useForall) {
631 staticTileInterchange);
636 auto staticTileInterchangeAttr =
638 unsigned numExpectedLoops =
639 useForall ? 1 : staticTileSizes.size() - llvm::count(staticTileSizes, 0);
641 resultTypes.reserve(numExpectedLoops);
642 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
643 "expected one loop type or as many as loops");
644 if (loopTypes.size() == 1)
645 resultTypes.append(numExpectedLoops, loopTypes[0]);
647 llvm::append_range(resultTypes, loopTypes);
648 build(builder, result, target.
getType(),
652 dynamicTileInterchange,
654 staticTileInterchangeAttr,
661 template <
typename Range>
665 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
671 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
672 if (!tilingInterfaceOp)
673 return transformOp->
emitError(
"only TilingInterface ops are supported");
676 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
677 applyFn(tilingInterfaceOp);
683 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
684 for (
Operation *toReplace : opsToReplace) {
685 for (
OpResult res : toReplace->getResults())
686 if (
auto replacement = tiledResults->replacements.lookup(res))
688 if (toReplace->use_empty()) {
694 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
695 assert(tiledResults->loops.size() == numLoops &&
696 "Mismatched number of loops, tile and fuse transform should have "
698 for (
unsigned int i = 0; i < numLoops; ++i)
699 loopOps[i].push_back(tiledResults->loops[i]);
702 transformResults.
set(transformOp->
getOpResult(0), tiledLinalgOps);
703 for (
unsigned int i = 0; i < numLoops; ++i)
704 transformResults.
set(transformOp->
getOpResult(i + 1), loopOps[i]);
713 auto transformOp = cast<TransformOpInterface>(getOperation());
717 state, transformOp, getMixedTileSizes(), tileSizes);
722 state, transformOp, getMixedTileInterchange(), tileInterchange);
728 bool useForall = getUseForall();
734 tilingOptions = tilingOptions.
setTileSizes(tileSizesOfr);
738 if (getApplyCleanup()) {
741 tensor::ExtractSliceOp::getCanonicalizationPatterns(
patterns, context);
748 useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
750 rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
752 [&](TilingInterface tilingInterfaceOp)
753 -> FailureOr<scf::SCFTileAndFuseResult> {
762 auto iterspace_rank = getStaticTileSizes().size();
764 if (permutation.size() > iterspace_rank)
766 <<
"interchange length exceeds iteration space dimensions ("
767 << iterspace_rank <<
"), found " << getTileInterchange();
769 for (int64_t v : permutation) {
770 if (!ShapedType::isDynamic(v)) {
771 if (v < 0 || v >=
static_cast<int64_t
>(iterspace_rank))
772 return emitOpError() <<
"expects interchange values to be in range [0, "
773 << iterspace_rank <<
"), found: " << v;
775 return emitOpError() <<
"found duplicate interchange value: " << v;
781 size_t numExpectedLoops =
782 getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
783 if (numExpectedLoops != getNumResults() - 1)
784 return emitOpError() <<
"expects " << numExpectedLoops <<
" loop results";
794 return getMixedValues(getStaticTileInterchange(), getTileInterchange(),
798 void transform::FuseOp::getEffects(
811 void transform::FuseIntoContainingOp::build(
OpBuilder &builder,
814 Value containingOp) {
817 result.
addTypes({resultType, resultType});
833 (domInfo.
dominates(containingOp, user))) {
834 dominatedUsers.insert(user);
837 if (dominatedUsers.empty())
841 auto forallOp = cast<scf::ForallOp>(containingOp);
847 auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
852 newOuts.push_back(outputs[resultNumber]);
855 auto newforallOp = scf::ForallOp::create(
856 rewriter, loc, forallOp.getMixedLowerBound(),
857 forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
858 forallOp.getMapping());
860 newforallOp.getRegion().takeBody(forallOp.getRegion());
865 newforallOp.getBody()->addArgument(newOuts.back().getType(),
866 newOuts.back().getLoc());
867 auto bbArgs = newforallOp.getBody()->getArguments();
870 Operation *op = use.getOwner();
871 return newforallOp->isProperAncestor(op);
875 scf::InParallelOp terminatorOp = newforallOp.getTerminator();
877 terminatorOp.getYieldingOps(), [](
Operation &op) { return &op; }));
878 Operation *firstYieldOp = yieldingOps.front();
881 Value dst = newforallOp.getRegionIterArgs().back();
883 tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->
getLoc(), src,
884 dst, offsets, sizes, strides);
888 newforallOp->getResult(result.index()));
891 newforallOp->getResults().back(),
893 Operation *user = use.getOwner();
894 return dominatedUsers.contains(user);
908 destWorklist.push_back(dst);
910 while (!destWorklist.empty()) {
911 Value currentDst = destWorklist.pop_back_val();
915 if (src == currentDst)
920 auto bbArg = dyn_cast<BlockArgument>(currentDst);
924 Block *parentBlock = bbArg.getOwner();
925 assert(parentBlock &&
"unlinked block argument");
928 assert(parentOp &&
"expected block argument with parent operation");
931 auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
935 for (
auto innerIterArg : parentLoop.getRegionIterArgs()) {
937 OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
938 Value loopBlockArgument =
940 destWorklist.push_back(loopBlockArgument);
953 static std::tuple<SmallVector<Operation *>,
Operation *>
956 LDBG() <<
"Try to fuse a direct extract use";
957 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
958 if (!tileableProducer) {
960 <<
"producer is not a TileableInterface: " << *producerOp;
967 auto it = llvm::find_if(tileableProducer->getUsers(), [&](
Operation *user) {
968 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
969 return sliceOp && containingOp->isProperAncestor(sliceOp);
973 if (it == tileableProducer->getUsers().end()) {
974 diag.attachNote(tileableProducer->getLoc())
975 <<
"could not find fusion opportunity for: " << *tileableProducer;
978 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
991 if (LoopLikeOpInterface containerLoop =
992 dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
999 cast<DestinationStyleOpInterface>(
clone).getDpsInitsMutable()) {
1000 Value producerOperand =
1003 containerLoop.getRegionIterArgs()) {
1004 OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
1005 Value consumerOperand =
1006 containerLoop->getOperand(bbArg->getOperandNumber());
1008 if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
1009 initOperandPtr.set(containerIterArg);
1015 tileableProducer = dyn_cast<TilingInterface>(
clone);
1019 int64_t resultNumber =
1020 cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
1021 LDBG() <<
"resultNumber: " << resultNumber;
1026 FailureOr<TilingResult> tileAndFuseResult =
1027 tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
1030 if (
failed(tileAndFuseResult)) {
1031 diag.attachNote(tileableProducer->getLoc())
1032 <<
"failed to tile producer op: " << *tileableProducer;
1037 for (
auto *tiledOp : tileAndFuseResult->tiledOps) {
1038 LDBG() <<
"tiledProducer: " << *tiledOp;
1043 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1044 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1045 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1046 if (
failed(maybeRankReduced)) {
1048 <<
"shape types don't match (missing canonicalization?):\nTiledOp: "
1049 << tileAndFuseResult->tiledValues[0]
1050 <<
"\nSliceOp: " << sliceOpToTile.getOperation() <<
'\n';
1053 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1057 rewriter,
diag, producerOp, containingOp, *tileAndFuseResult,
1058 resultNumber, offsets, sizes);
1061 if (dyn_cast<LoopLikeOpInterface>(containingOp))
1062 rewriter.
eraseOp(tileableProducer);
1064 return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
1077 LDBG() <<
"Try to fuse an extract use through block argument";
1079 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
1080 if (!tileableProducer) {
1082 <<
"producer is not a TileableInterface: " << *producerOp;
1087 scf::ForallOp forallOp;
1088 auto itProducerUses =
1089 llvm::find_if(tileableProducer->getUses(), [&](
OpOperand &use) {
1090 forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
1094 if (!forallOp || forallOp != containingOp) {
1095 diag.attachNote(tileableProducer->getLoc())
1096 <<
"could not find a use by the containing op: " << *tileableProducer;
1111 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
1112 return sliceOp && containingOp->isProperAncestor(sliceOp);
1116 if (itBBArgUsers == bbArg.
getUsers().end()) {
1118 <<
"could not find fusion opportunity for bbArg: " << bbArg;
1121 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
1129 int64_t resultNumber = cast<OpResult>(pUse->
get()).getResultNumber();
1130 LDBG() <<
"resultNumber: " << resultNumber;
1135 rewriter, tileableProducer->getLoc(), tileableProducer,
1136 destinationTensors))) {
1137 diag.attachNote(tileableProducer->getLoc())
1138 <<
"failed to get destination tensors for: " << *tileableProducer;
1143 bvm.
map(destinationTensors[resultNumber], bbArg);
1144 auto tileableProducerClone =
1145 cast<TilingInterface>(rewriter.
clone(*tileableProducer, bvm));
1147 llvm::make_scope_exit([&]() { rewriter.
eraseOp(tileableProducerClone); });
1150 FailureOr<TilingResult> tileAndFuseResult =
1151 tileableProducerClone.generateResultTileValue(
1152 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
1153 sliceOpToTile.getMixedSizes());
1154 if (
failed(tileAndFuseResult)) {
1155 diag.attachNote(tileableProducer->getLoc())
1156 <<
"failed to tile producer op: " << *tileableProducer;
1161 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
1162 rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
1163 cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
1164 assert(succeeded(maybeRankReduced) &&
"unexpected shape");
1165 rewriter.
replaceOp(sliceOpToTile, *maybeRankReduced);
1170 destinationTensors.front());
1173 return tileAndFuseResult->tiledOps;
1179 LDBG() <<
"Try to fuse an use by cloning";
1184 for (
OpOperand &use : result.getUses()) {
1186 uses.push_back(&use);
1191 if (containingOp == use.getOwner()) {
1193 <<
"producer op use by containing op cannot be fused by cloning";
1201 diag.attachNote(producerOp->
getLoc()) <<
"no fusion opportunity by cloning";
1210 assert(!isa<tensor::ParallelInsertSliceOp>(use->
getOwner()) &&
1211 "Parallel insert slice is not a valid clone destination");
1212 unsigned resultNumber = cast<OpResult>(use->
get()).getResultNumber();
1213 LDBG() <<
"resultNumber: " << resultNumber;
1217 fusedOp = rewriter.
clone(*producerOp);
1219 use->
getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1224 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1234 auto producerOps = state.getPayloadOps(getProducerOp());
1235 auto containingOps = state.getPayloadOps(getContainingOp());
1236 if (!llvm::hasSingleElement(containingOps)) {
1238 <<
"requires exactly one containing_op handle (got "
1239 << llvm::range_size(containingOps) <<
")";
1241 Operation *containingOp = *containingOps.begin();
1244 if (std::empty(producerOps)) {
1246 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1253 auto getNextProducer = [&]() -> FailureOr<Operation *> {
1254 for (
const auto &it :
enumerate(remainingProducers)) {
1257 int64_t numUsesInContainingOp =
1259 return containingOp->isAncestor(op);
1264 if (numUsesInContainingOp > 0) {
1265 if (numUsesInContainingOp == 1)
1266 remainingProducers.erase(remainingProducers.begin() + it.index());
1273 while (!remainingProducers.empty()) {
1274 auto nextProducer = getNextProducer();
1275 if (
failed(nextProducer)) {
1277 <<
"could not find next producer to fuse into container";
1278 diag.attachNote(containingOp->
getLoc()) <<
"containing op";
1286 diag <<
"could not fuse " << *producerOp <<
" into " << *containingOp;
1293 auto [tiledOps, newContainingOp] =
1295 if (!tiledOps.empty()) {
1296 LDBG() <<
"\nFused a direct extract use\n" << *containingOp;
1297 fusedOps.append(tiledOps);
1298 if (newContainingOp) {
1306 LogicalResult replacementStatus =
1309 (void)replacementStatus;
1310 assert(succeeded(replacementStatus) &&
1311 "unable to update transform state mapping");
1312 rewriter.
eraseOp(containingOp);
1313 containingOp = newContainingOp;
1320 rewriter,
diag, producerOp, containingOp);
1321 if (!tiledContainingOpOperand.empty()) {
1322 LDBG() <<
"\nFused an extract use through block argument\n"
1324 fusedOps.append(tiledContainingOpOperand);
1331 LDBG() <<
"\nFused an use by cloning\n" << *containingOp;
1332 fusedOps.push_back(cloned);
1338 results.
set(cast<OpResult>(getFusedOp()), fusedOps);
1339 results.
set(cast<OpResult>(getNewContainingOp()), {containingOp});
1343 void transform::FuseIntoContainingOp::getEffects(
1361 if (isa<GenericOp>(target)) {
1367 if (succeeded(
generic)) {
1368 results.
push_back(generic->getOperation());
1371 return emitDefaultSilenceableFailure(target);
1384 if (!isa<GenericOp>(target)) {
1389 FailureOr<LinalgOp> named =
1391 if (succeeded(named)) {
1392 results.
push_back(named->getOperation());
1395 return emitDefaultSilenceableFailure(target);
1409 if (interchangeVector.empty()) {
1414 unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1415 if (interchangeVector.size() != numLoops) {
1416 return emitSilenceableError()
1417 << getIteratorInterchangeAttrName() <<
" has length ("
1418 << interchangeVector.size()
1419 <<
") different from the number of loops in the target operation ("
1432 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1433 if (!std::is_permutation(sequence.begin(), sequence.end(),
1434 permutation.begin(), permutation.end())) {
1435 return emitOpError()
1436 <<
"expects iterator_interchange to be a permutation, found "
1437 << getIteratorInterchange();
1452 if (!isa<linalg::CopyOp>(targetOp)) {
1454 emitSilenceableError() <<
"only linalg.copy target ops are supported";
1455 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1459 auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1460 if (!copyOp.hasPureBufferSemantics()) {
1462 emitSilenceableError()
1463 <<
"cannot transform a linalg.copy on tensors into a memref.copy";
1464 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1470 assert(inputs.size() == 1 &&
"expected linalg copy op with one input");
1471 assert(outputs.size() == 1 &&
"expected memref copy op with one output");
1472 Value input = inputs.front();
1473 Value output = outputs.front();
1478 if (!isa<ShapedType>(input.
getType())) {
1480 emitSilenceableError()
1481 <<
"cannot transform a linalg.copy which input has no shape";
1482 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1487 assert(isa<ShapedType>(output.getType()));
1489 if (cast<ShapedType>(input.
getType()).getElementType() !=
1490 cast<ShapedType>(output.getType()).getElementType()) {
1492 emitSilenceableError()
1493 <<
"cannot transform a linalg.copy with different source and "
1494 "destination element types ";
1495 diag.attachNote(targetOp->
getLoc()) <<
"target op";
1516 bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1517 FailureOr<LowerPackResult> res =
1518 lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1521 <<
"cannot lower to pad + expand + transpose";
1524 transformResults.
push_back(res->expandShapeOp);
1525 transformResults.
push_back(res->transposeOp);
1538 bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1539 FailureOr<LowerUnPackOpResult> res =
1540 lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1543 emitSilenceableError()
1544 <<
"cannot lower to transpose + collapse + extract";
1545 diag.attachNote(target->getLoc()) <<
"target payload op";
1548 transformResults.
push_back(res->emptyOp);
1549 transformResults.
push_back(res->transposeOp);
1550 transformResults.
push_back(res->collapseShapeOp);
1551 transformResults.
push_back(res->extractSliceOp);
1581 if (getOps().has_value())
1582 strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1584 auto payloadOps = state.getPayloadOps(getTarget());
1585 if (!llvm::hasSingleElement(payloadOps)) {
1590 bool incorrectNumOperandTypes =
false;
1597 if (getInterface().has_value()) {
1598 auto iface = getInterface().value();
1599 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1602 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1603 !isa<TilingInterface>(op))
1605 if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1606 !isa<LoopLikeOpInterface>(op))
1611 if (getOpAttrs().has_value()) {
1612 DictionaryAttr opAttrs = getOpAttrs().value();
1614 if (attr.getName() == getInterfaceAttrName() ||
1615 attr.getName() == getOpsAttrName())
1617 if (!op->
hasAttr(attr.getName()))
1619 if (op->
getAttr(attr.getName()) != attr.getValue())
1624 if (getFilterResultType().has_value()) {
1625 Type t = getFilterResultType().value();
1630 if (getFilterOperandTypes().has_value()) {
1631 mlir::ArrayAttr types = getFilterOperandTypes().value();
1634 if (types.size() == 1) {
1637 dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1638 Type t = cast<::mlir::Type>(typeattr.getValue());
1640 [&](
Type operandType) { return operandType == t; }))
1645 if (types.size() != operandTypes.size()) {
1646 incorrectNumOperandTypes =
true;
1650 for (
auto [attr, operandType] :
1651 llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1652 auto typeattr = cast<mlir::TypeAttr>(attr);
1653 Type type = cast<::mlir::Type>(typeattr.getValue());
1655 if (type != operandType)
1666 (*payloadOps.begin())->
walk(matchFun);
1667 if (incorrectNumOperandTypes)
1669 "type, then it must contain as much types as "
1670 "the number of operands in the target ops");
1671 results.
set(cast<OpResult>(getResult()), res);
1686 Type &targetType,
Type &lowSizeType,
1688 Type &splitPointType) {
1689 FunctionType funcType;
1694 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1695 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
1696 "argument and one result";
1698 targetType = funcType.getInput(0);
1699 lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1707 if (isa<TransformParamTypeInterface>(getLowSize().
getType())) {
1708 if (target.hasDynamicShape()) {
1709 auto diag = emitSilenceableError()
1710 <<
"cannot compute parametric tile sizes for dynamically "
1711 "shaped payload op";
1712 diag.attachNote(target->getLoc()) <<
"payload op";
1717 target, getDimension(), getTargetSize(), getDivisor());
1719 return emitSilenceableError()
1720 <<
"failed to compute multi-size tiling sizes";
1723 Builder builder(target.getContext());
1724 results.
assign(llvm::map_range(
1726 spec->lowTileSize * spec->lowTripCount}),
1727 [&builder,
this](int64_t value) {
1739 builder, target, getDimension(), targetSize, divisor);
1741 return emitSilenceableError() <<
"could not generate tile size computation";
1748 {spec->lowTileSize, spec->lowTripCount});
1749 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1750 Operation *highTileSize = spec->highTileSize.getDefiningOp();
1751 assert(lowTileSize && highTileSize && splitPoint &&
1752 "tile sizes are not produced by operations");
1760 void transform::MultiTileSizesOp::getEffects(
1764 if (isa<TransformParamTypeInterface>(getLowSize().
getType()))
1773 return emitOpError() <<
"expects all results type to be the same";
1793 builder.
getContext(), GenericOp::getOperationName());
1794 build(builder, result,
1803 return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1810 auto targetOps = state.getPayloadOps(getTarget());
1812 if (std::empty(targetOps)) {
1813 transformResults.
set(cast<OpResult>(getPackedOp()),
1818 auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1819 if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1820 return emitSilenceableError()
1821 <<
"requires target to map to exactly 1 LinalgOp (got "
1822 << llvm::range_size(targetOps) <<
")";
1825 if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1826 return emitSilenceableError()
1827 <<
"requires number of packed sizes match the number of loops ("
1828 << getMixedPackedSizes().size() <<
" vs " << linalgOp.getNumLoops()
1835 state, *
this, packedSizes, getMixedPackedSizes());
1838 FailureOr<PackResult> maybeResult =
pack(rewriter, linalgOp, packedSizes);
1842 transformResults.
set(cast<OpResult>(getPackedOp()),
1843 {maybeResult->packedLinalgOp.getOperation()});
1847 void transform::PackOp::getEffects(
1861 return emitOpError() << getMatmulInnerDimsOrderAttrName()
1862 <<
" is not a valid permutation";
1865 if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1866 for (
auto [s, nmo] :
1867 llvm::zip_equal(getMixedMatmulPackedSizes(),
1868 getMatmulPaddedSizesNextMultipleOf())) {
1871 (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1872 return emitOpError() <<
"at most one of the packed_size and the "
1873 "padded_sizes_next_multiple_of can be nonzero "
1874 "for the matmul strategy";
1886 for (
Operation *op : state.getPayloadOps(getTarget())) {
1887 auto linalgOp = dyn_cast<LinalgOp>(op);
1898 getMixedMatmulPackedSizes(),
1900 getMatmulPaddedSizesNextMultipleOf(),
1901 getMatmulInnerDimsOrder());
1902 if (succeeded(packResult)) {
1903 results.push_back(packResult->packedLinalgOp);
1906 results.push_back(linalgOp);
1908 transformResults.
set(cast<OpResult>(getPackedOp()), results);
1914 return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1918 void transform::PackGreedilyOp::getEffects(
1932 return emitOpError() << getInnerPermAttrName()
1933 <<
" is not a valid permutation";
1936 return emitOpError() << getOuterPermAttrName()
1937 <<
" is not a valid permutation";
1939 if (getInnerPerm().empty() && getOuterPerm().empty()) {
1940 return emitOpError() <<
" at least one of " << getInnerPermAttrName()
1941 <<
" or " << getOuterPermAttrName()
1942 <<
" must be specified";
1948 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1958 template <
typename RelayoutOpTy>
1961 OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1963 llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1964 "applies to only pack or unpack operations");
1965 if (!op || permutation.empty())
1967 size_t innerRank = op.getInnerDimsPos().size();
1968 if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1972 if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1973 return permutation.size() == op.getSourceRank() &&
1976 return permutation.size() == op.getDestRank() &&
1984 auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1985 auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1987 if (std::empty(packOrUnpackOps)) {
1988 transformResults.
set(cast<OpResult>(getPackedOp()), {});
1989 transformResults.
set(cast<OpResult>(getPackOp()), {});
1990 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
1996 if (!llvm::hasSingleElement(packOrUnpackOps) ||
1997 !llvm::hasSingleElement(linalgOps)) {
1998 return emitSilenceableError()
1999 <<
"requires target to map to exactly 1 "
2000 "packing op and 1 packed op ("
2001 <<
"got " << llvm::range_size(packOrUnpackOps) <<
" and "
2002 << llvm::range_size(linalgOps) <<
")";
2006 auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
2007 auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
2008 if ((!packOp && !unPackOp)) {
2009 return emitSilenceableError() <<
"requires target to map to a "
2010 "linalg.pack or linalg.unpack";
2012 LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
2013 if (!linalgOpTarget)
2014 return emitSilenceableError() <<
"requires a LinalgOp target";
2018 if (packOp && packOp.getResult().hasOneUse())
2019 linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
2021 linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
2022 if (linalgOp != linalgOpTarget) {
2024 packOp ? StringLiteral{
"not a single use by the LinalgOp target"}
2025 : StringLiteral{
"not produced by the LinalgOp target"};
2026 return emitSilenceableError() << errorMsg;
2032 assert(!packOp &&
"packOp must be null on entry when unPackOp is not null");
2033 OpOperand *packUse = linalgOp.getDpsInitOperand(
2034 cast<OpResult>(unPackOp.getSource()).getResultNumber());
2036 if (!packOp || !packOp.getResult().hasOneUse())
2037 return emitSilenceableError() <<
"could not find matching pack op";
2041 for (
auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
2043 (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
2044 auto errorMsg = (permType == OuterOrInnerPerm::Outer)
2045 ? StringLiteral{
"invalid outer_perm"}
2046 : StringLiteral{
"invalid inner_perm"};
2050 unPackOp ? unPackOp.getOperation() : packOp.getOperation();
2051 return emitSilenceableError() << errorMsg <<
": " << *packOrUnpackOp;
2057 assert(packOp && linalgOp &&
"unexpected null op");
2061 rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
2063 assert(succeeded(res) &&
"unexpected packTranspose failure");
2066 transformResults.
set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
2067 transformResults.
set(cast<OpResult>(getPackedOp()),
2068 {res->transposedLinalgOp});
2070 transformResults.
set(cast<OpResult>(getUnPackOp()),
2071 {res->transposedUnPackOp});
2073 transformResults.
set(cast<OpResult>(getUnPackOp()), {});
2088 StringRef copyBackOp,
2089 bool usePrescribedTensorShapes) {
2099 (padToMultipleOf.empty()
2101 : b.getDenseI64ArrayAttr(padToMultipleOf)),
2102 b.getI64ArrayAttr(nofoldFlags),
2103 b.getArrayAttr(transposePaddings),
2104 b.getStringAttr(copyBackOp),
2106 usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
2114 StringRef copyBackOp,
2115 bool usePrescribedTensorShapes) {
2120 staticPadToMultipleOf);
2127 dynamicPadToMultipleOf,
2128 staticPadToMultipleOf,
2132 usePrescribedTensorShapes);
2135 void PadOp::getEffects(
2145 return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
2152 auto transformOp = cast<TransformOpInterface>(getOperation());
2155 for (
Operation *target : state.getPayloadOps(getTarget())) {
2156 auto linalgTarget = dyn_cast<LinalgOp>(target);
2157 if (!linalgTarget) {
2158 auto diag = emitSilenceableError() <<
"expected LinalgOp target";
2159 diag.attachNote(target->
getLoc()) <<
"target op";
2165 for (int64_t packPadding :
2166 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
2167 nofoldFlags.push_back(
static_cast<bool>(packPadding));
2171 for (
auto const &[untypedAttr, elementOrTensorType] :
2172 llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
2174 if (isa<ub::PoisonAttr>(untypedAttr)) {
2175 paddingValues.push_back(untypedAttr);
2178 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2180 emitOpError(
"expects padding values to be typed attributes or poison");
2185 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2189 if (!parsedAttr || parsedAttr.getType() != elementType) {
2190 auto diag = this->emitOpError(
"expects a padding that parses to ")
2191 << elementType <<
", got " << untypedAttr;
2192 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2195 paddingValues.push_back(parsedAttr);
2199 if (attr.getType() != elementType) {
2200 auto diag = this->emitOpError(
"expects a padding value of type ")
2201 << elementType <<
", got " << attr;
2202 diag.attachNote(linalgTarget.getLoc()) <<
"when applied to this op";
2205 paddingValues.push_back(attr);
2210 for (
Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2211 transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2212 cast<ArrayAttr>(transposeVector)));
2217 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2221 state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2224 if (padToMultipleOf.empty())
2228 options.padToMultipleOf = padToMultipleOf;
2229 options.paddingValues = paddingValues;
2230 options.nofoldFlags = nofoldFlags;
2231 if (getCopyBackOp() ==
2232 bufferization::MaterializeInDestinationOp::getOperationName()) {
2235 }
else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2237 }
else if (getCopyBackOp() == kCopyOpNone) {
2240 llvm_unreachable(
"unsupported copy_back op");
2243 bool irChanged =
false;
2244 if (getUsePrescribedTensorShapes() &&
2245 linalgTarget.hasPureTensorSemantics()) {
2248 for (
OpOperand &operand : linalgTarget->getOpOperands()) {
2249 for (
auto [i, dim] :
llvm::enumerate(linalgTarget.getShape(&operand))) {
2250 if (ShapedType::isStatic(dim))
2252 options.setSizeToPadTo(operand.getOperandNumber(), i,
2254 operand.get().getLoc(),
2264 replacements, newPadOps))) {
2267 diag.attachNote(target->
getLoc()) <<
"target op";
2270 auto diag = emitSilenceableError() <<
"failed to pad op";
2271 diag.attachNote(target->
getLoc()) <<
"target op";
2280 rewriter.
replaceOp(linalgTarget, replacements);
2281 paddedOps.push_back(paddedOp);
2282 padOps.append(newPadOps.begin(), newPadOps.end());
2284 for (
Value v : replacements) {
2285 Operation *copyBackOp = v.getDefiningOp();
2286 if (!llvm::is_contained(copyBackOps, copyBackOp))
2287 copyBackOps.push_back(copyBackOp);
2292 results.
set(cast<OpResult>(getPadded()), paddedOps);
2293 results.
set(cast<OpResult>(getPad()), padOps);
2294 results.
set(cast<OpResult>(getCopy()), copyBackOps);
2300 extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2301 if (any_of(nofoldFlags, [](int64_t packPadding) {
2302 return packPadding != 0 && packPadding != 1;
2304 return emitOpError()
2305 <<
"expects nofold_flags to contain booleans (0/1), found "
2306 << getNofoldFlags();
2310 extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2311 if (any_of(paddingDimensions,
2312 [](int64_t paddingDimension) {
return paddingDimension < 0; })) {
2313 return emitOpError() <<
"expects padding_dimensions to contain positive "
2315 << getPaddingDimensions();
2317 if (!getMixedPadToMultipleOf().empty()) {
2318 if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2319 return emitOpError() <<
"expects as many multiples as padding_dimensions";
2322 ArrayAttr transposes = getTransposePaddings();
2325 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2326 if (!std::is_permutation(sequence.begin(), sequence.end(),
2327 transpose.begin(), transpose.end())) {
2328 return emitOpError()
2329 <<
"expects transpose_paddings to be a permutation, found "
2333 if (getCopyBackOp() !=
2334 bufferization::MaterializeInDestinationOp::getOperationName() &&
2335 getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2336 getCopyBackOp() != kCopyOpNone)
2337 return emitOpError() <<
"invalid copy_back_op";
2345 void transform::PadTilingInterfaceOp::build(
OpBuilder &b,
2349 bool padToMultipleOf) {
2359 : b.getDenseI64ArrayAttr(paddingSizes)),
2361 padToMultipleOf ? b.getUnitAttr() : nullptr);
2364 void transform::PadTilingInterfaceOp::build(
2371 staticPaddingSizes);
2377 dynamicPaddingSizes,
2382 void transform::PadTilingInterfaceOp::getEffects(
2391 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2393 return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2402 for (
Operation *target : state.getPayloadOps(getTarget())) {
2403 auto targetOp = dyn_cast<TilingInterface>(target);
2405 auto diag = emitSilenceableError() <<
"expected TilingInterface target";
2406 diag.attachNote(target->
getLoc()) <<
"target op";
2413 if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2414 auto diag = emitSilenceableError() <<
"only IndexingMapOpInterface ops "
2416 diag.attachNote(target->
getLoc()) <<
"target op";
2422 for (
auto const &[untypedAttr, elementOrTensorType] :
2423 llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2424 auto attr = dyn_cast<TypedAttr>(untypedAttr);
2427 if (isa<ub::PoisonAttr>(untypedAttr)) {
2428 paddingValues.push_back(untypedAttr);
2432 emitOpError(
"expects padding values to be typed attributes or poison");
2436 if (
auto stringAttr = dyn_cast<StringAttr>(attr)) {
2440 if (!parsedAttr || parsedAttr.getType() != elementType) {
2441 auto diag = this->emitOpError(
"expects a padding that parses to ")
2442 << elementType <<
", got " << attr;
2443 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2446 paddingValues.push_back(parsedAttr);
2450 if (attr.getType() != elementType) {
2451 auto diag = this->emitOpError(
"expects a padding value of type ")
2452 << elementType <<
", got " << attr;
2453 diag.attachNote(targetOp.getLoc()) <<
"when applied to this op";
2456 paddingValues.push_back(attr);
2460 TilingInterface paddedOp;
2462 options.setPaddingValues(paddingValues)
2463 .setPaddingSizes(getMixedPaddingSizes())
2464 .setPadToMultipleOf(getPadToMultipleOf());
2469 rewriter, cast<TilingInterface>(targetOp.getOperation()),
options,
2471 if (
failed(maybePaddedOp)) {
2472 auto diag = emitSilenceableError() <<
"failed to pad op";
2473 diag.attachNote(target->
getLoc()) <<
"target op";
2478 paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2479 padOps.append(newPadOps.begin(), newPadOps.end());
2482 results.
set(cast<OpResult>(getPadded()), paddedOps);
2483 results.
set(cast<OpResult>(getPad()), padOps);
2497 auto targetOps = state.getPayloadOps(getTarget());
2498 auto loopOps = state.getPayloadOps(getLoop());
2499 if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2501 <<
"requires exactly one target and one loop handle (got "
2502 << llvm::range_size(targetOps) <<
" and "
2503 << llvm::range_size(loopOps) <<
")";
2506 auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2507 auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2508 if (!padOp || !loopOp)
2511 FailureOr<linalg::detail::PackingResult> result =
2517 if (result->clonedLoopIvs.empty()) {
2518 transformResults.
set(cast<OpResult>(getPackingLoop()),
2519 {result->hoistedPadOp.getOperation()});
2522 auto outerPackedLoop =
2524 transformResults.
set(cast<OpResult>(getPackingLoop()),
2525 {outerPackedLoop.getOperation()});
2531 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2532 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2534 return emitOpError() <<
"expects transpose to be a permutation, found "
2540 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2550 tensor::PadOp target,
2553 tensor::PadOp hoistedPadOp;
2555 FailureOr<Value> result =
2557 hoistedPadOp, transposeOps);
2558 if (succeeded(result)) {
2568 return emitDefaultSilenceableFailure(target);
2573 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2574 if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2576 return emitOpError() <<
"expects transpose to be a permutation, found "
2592 if (!getOperandsToPromote().empty())
2594 extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2595 if (getUseFullTilesByDefault())
2597 getUseFullTilesByDefault());
2598 if (getUseOriginalSubviewSize())
2602 promotionOptions = promotionOptions.
setUseAlloca(getUseAlloca());
2603 if (!getUseFullTileBuffers().empty())
2605 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2606 if (getAlignment().has_value())
2607 promotionOptions = promotionOptions.
setAlignment(*getAlignment());
2608 if (getMemorySpace().has_value())
2609 promotionOptions = promotionOptions.
setMemorySpace(*getMemorySpace());
2611 if (getMapping().has_value()) {
2613 auto mapping = *getMapping();
2614 if (mapping.size() > 1)
2615 return emitDefaultDefiniteFailure(target);
2617 auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2619 if (addressSpace.getAddressSpace() ==
2620 mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2627 }
else if (addressSpace.getAddressSpace() ==
2628 mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2636 return emitDefaultDefiniteFailure(target);
2641 return emitDefaultDefiniteFailure(target);
2644 FailureOr<LinalgOp> res =
promoteSubViews(rewriter, target, promotionOptions);
2646 return emitDefaultDefiniteFailure(target);
2659 auto payload = state.getPayloadOps(getTarget());
2663 if (target->getNumOperands() > 0)
2666 target->getNumRegions() > 0)
2668 <<
"expected target that is isolated from above";
2672 Operation *pattern = &getBodyRegion().front().front();
2675 if (getOperation()->isAncestor(target))
2680 replacements.push_back(replacement);
2682 transformResults.
set(cast<OpResult>(getReplacement()), replacements);
2686 void transform::ReplaceOp::getEffects(
2694 if (!getBodyRegion().hasOneBlock())
2695 return emitOpError() <<
"expected one block";
2696 if (std::distance(getBodyRegion().front().begin(),
2697 getBodyRegion().front().end()) != 1)
2698 return emitOpError() <<
"expected one operation in block";
2699 Operation *replacement = &getBodyRegion().front().front();
2702 <<
"expected replacement without operands";
2706 <<
"expect op that is isolated from above";
2724 target.createFlatListOfOperandDims(b, loc);
2725 AffineMap map = target.getShapesToLoopsMap();
2740 FailureOr<scf::SCFTilingResult> maybeTilingResult =
tileUsingSCF(
2741 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2742 if (
failed(maybeTilingResult))
2743 return emitDefaultDefiniteFailure(target);
2745 if (target->getNumResults())
2746 rewriter.
replaceOp(target, maybeTilingResult->replacements);
2750 results.
reserve(maybeTilingResult->tiledOps.size());
2751 for (
Operation *tiled : maybeTilingResult->tiledOps)
2765 for (
Operation *target : state.getPayloadOps(getTarget())) {
2766 auto tilingOp = dyn_cast<TilingInterface>(*target);
2769 emitSilenceableError()
2770 <<
"expected the payload to implement TilingInterface";
2771 diag.attachNote(target->getLoc()) <<
"payload op";
2775 FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2777 if (
failed(generatedLoops))
2778 return emitDefaultDefiniteFailure(target);
2779 for (scf::ForOp &loop : *generatedLoops) {
2780 loops.push_back(loop.getOperation());
2784 results.
set(cast<OpResult>(getResult()), loops);
2793 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2798 FailureOr<Operation *> maybeResult =
2800 .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2801 [&rewriter](
auto op) {
2805 return emitDefaultSilenceableFailure(target);
2819 llvm::to_vector(state.getPayloadOps(getTarget()));
2821 bool isMultiwaySplit = getMultiway();
2823 if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2825 <<
"requires exactly one target when "
2826 "multiway split is enabled (got "
2827 << llvm::range_size(payload) <<
")";
2832 if (!isMultiwaySplit)
2833 chunkSizes.reserve(payload.size());
2835 if (getDynamicChunkSizes()) {
2837 if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().
getType())) {
2838 chunkSizes = llvm::to_vector(llvm::map_range(
2839 state.getPayloadOps(getDynamicChunkSizes()), [&](
Operation *op) {
2842 diag = emitSilenceableError()
2843 <<
"expected dynamic split point handle to point to a "
2844 "single-result index-typed op";
2845 diag.attachNote(op->getLoc()) <<
"dynamic split point";
2850 chunkSizes = llvm::to_vector(
2851 llvm::map_range(state.getParams(getDynamicChunkSizes()),
2854 if (
diag.isSilenceableFailure())
2859 if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2861 <<
"expected the dynamic split point handle to point to as "
2863 << chunkSizes.size() <<
") as the target handle ("
2864 << payload.size() <<
")";
2867 chunkSizes.resize(payload.size(),
2871 auto checkStructuredOpAndDimensions =
2874 auto diag = emitSilenceableError() <<
"only applies to structured ops";
2875 diag.attachNote(loc) <<
"target op";
2879 if (getDimension() >= linalgOp.getNumLoops()) {
2880 auto diag = emitSilenceableError() <<
"dimension " << getDimension()
2881 <<
" does not exist in target op";
2882 diag.attachNote(loc) <<
"target op";
2888 auto checkFailureInSplitting =
2892 diag.attachNote(loc) <<
"target op";
2899 if (isMultiwaySplit) {
2902 TilingInterface head, tail;
2905 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2909 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2910 if (
diag.isSilenceableFailure())
2916 target = tail.getOperation();
2921 linalgOp = cast<LinalgOp>(target);
2926 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2927 getDimension(), chunkSize);
2931 checkFailureInSplitting(!head && !tail, loc);
2932 if (
diag.isDefiniteFailure())
2935 opList.push_back(head.getOperation());
2940 opList.push_back(tail.getOperation());
2946 for (
const auto &pair : llvm::zip(payload, chunkSizes)) {
2949 LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2951 checkStructuredOpAndDimensions(linalgOp, target->
getLoc());
2953 if (
diag.isSilenceableFailure())
2957 std::tie(first.emplace_back(), second.emplace_back()) =
linalg::splitOp(
2958 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2959 getDimension(), std::get<1>(pair));
2963 checkFailureInSplitting(!first.back() && !second.back(), loc);
2968 if (!second.back()) {
2969 noSecondPart = target;
2974 if (second.size() != first.size() && !second.empty()) {
2975 auto diag = emitSilenceableError()
2976 <<
"splitting does not produce the second part for a subset "
2979 <<
"expected splitting to produce the second part of all "
2980 "or none of the targets";
2982 <<
"first target with no second part";
2986 opList.append(first);
2987 if (!second.empty())
2988 opList.append(second);
2990 results.
set(cast<OpResult>(getSplitList()), opList);
2994 void SplitOp::getEffects(
2997 if (getDynamicChunkSizes())
3005 IntegerAttr staticChunkSizes;
3011 if (!dynamicPointParseResult.
has_value()) {
3012 int64_t staticChunkSizesValue;
3026 if (dynamicPointParseResult.
has_value()) {
3027 Type ChunkSizesType;
3040 SplitOp::getStaticChunkSizesAttrName(result.
name).getValue(),
3047 printer <<
" " << getTarget() <<
" after ";
3048 int64_t staticChunkSize =
static_cast<int64_t
>(getStaticChunkSizes());
3049 if (staticChunkSize != ShapedType::kDynamic)
3050 printer << staticChunkSize;
3052 printer << getDynamicChunkSizes();
3055 {getStaticChunkSizesAttrName()});
3056 printer <<
" : " << getTarget().getType();
3057 if (staticChunkSize == ShapedType::kDynamic)
3058 printer <<
", " << getDynamicChunkSizes().getType();
3062 if ((
static_cast<int64_t
>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
3063 (getDynamicChunkSizes() ==
nullptr)) {
3064 return emitOpError() <<
"expects either a dynamic or a static split "
3065 "point to be provided";
3074 void transform::SplitReductionOp::build(
3076 int64_t splitFactor, int64_t insertSplitDimension,
bool innerParallel,
3077 bool useScalingAlgorithm,
bool useAlloc) {
3083 SplitReductionOp::getInsertSplitDimensionAttrName(result.
name),
3085 if (innerParallel) {
3086 result.
addAttribute(SplitReductionOp::getInnerParallelAttrName(result.
name),
3089 if (useScalingAlgorithm) {
3091 SplitReductionOp::getUseScalingAlgorithmAttrName(result.
name),
3099 result.
addTypes({resultType, resultType, resultType, resultType});
3108 unsigned(getInsertSplitDimension()),
3109 bool(getInnerParallel())};
3112 FailureOr<SplitReductionResult> splitResult =
3113 (getUseScalingAlgorithm())
3117 return emitDefaultDefiniteFailure(target);
3119 results.
push_back(splitResult->initOrAlloc);
3121 results.
push_back(splitResult->splitLinalgOp);
3122 results.
push_back(splitResult->resultCombiningLinalgOp);
3130 void transform::TileReductionUsingForOp::build(
3141 build(builder, result,
3145 staticTileSizesAttr);
3154 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3155 if (!partialReductionOp) {
3158 "Operation should implement PartialReductionOpInterface");
3162 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3163 if (reductionDims.empty()) {
3164 for (
auto [idx, iteratorType] :
3166 if (iteratorType == utils::IteratorType::reduction)
3167 reductionDims.push_back(idx);
3173 options.setReductionTilingStrategy(
3176 options.setReductionDims(reductionDims);
3177 FailureOr<scf::SCFTilingResult> result =
3182 "failed to tile using partial reduction");
3184 rewriter.
replaceOp(target, result->replacements);
3185 for (
Value initValue : result->initialValues)
3187 for (
auto parallelTiledOp : result->tiledOps)
3189 for (
auto mergeOp : result->mergeOps)
3191 results.
push_back(result->loops.front());
3199 void transform::TileReductionUsingForallOp::build(
3202 ArrayAttr mapping) {
3212 build(builder, result,
3216 staticNumThreadsAttr,
3217 staticTileSizesAttr,
3227 auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3228 if (!partialReductionOp) {
3231 "Operation should implement PartialReductionOpInterface");
3240 options.setReductionTilingStrategy(
3242 if (!getNumThreads().empty()) {
3243 options.setNumThreads(numThreads);
3245 options.setTileSizes(tileSizes);
3247 if (
auto mapping = getMapping()) {
3248 options.setMapping(mapping.value().getValue());
3251 extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3252 if (reductionDims.empty()) {
3253 for (
auto [idx, iteratorType] :
3255 if (iteratorType == utils::IteratorType::reduction)
3256 reductionDims.push_back(idx);
3259 options.setReductionDims(reductionDims);
3260 FailureOr<scf::SCFTilingResult> result =
3264 auto diag = emitSilenceableError() <<
"could not tile reduction";
3267 rewriter.
replaceOp(target, result->replacements);
3269 for (
Value initValue : result->initialValues)
3271 for (
auto parallelTiledOp : result->tiledOps)
3273 for (
auto mergeOp : result->mergeOps)
3275 results.
push_back(result->loops.front());
3289 llvm::to_vector(state.getPayloadOps(getTarget()));
3291 if (!llvm::hasSingleElement(targetOps)) {
3293 <<
"requires exactly one target (got " << llvm::range_size(targetOps)
3298 auto linalgOp = dyn_cast<LinalgOp>(target);
3299 auto tileableOp = dyn_cast<TilingInterface>(target);
3304 OpBuilder builder(linalgOp.getContext());
3306 if (isa<TransformParamTypeInterface>(getChunkSizes().
getType())) {
3307 if (linalgOp.hasDynamicShape()) {
3308 auto diag = emitSilenceableError()
3309 <<
"cannot compute parametric tile sizes for dynamically "
3310 "shaped payload op";
3311 diag.attachNote(linalgOp->getLoc()) <<
"payload op";
3315 FailureOr<StaticContinuousTileSizeSpecification> spec =
3319 return emitSilenceableError()
3320 <<
"failed to compute multi-size tiling sizes";
3325 for (
auto &&[tileSize, tripCount] :
3326 llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3327 chunkSizes.push_back(tileSize * tripCount);
3330 return llvm::map_to_vector(values, [&](int64_t value) ->
Attribute {
3335 getI64AttrsFromI64(spec->tileSizes));
3336 transformResults.
setParams(cast<OpResult>(getChunkSizes()),
3337 getI64AttrsFromI64(chunkSizes));
3345 unsigned dimension = getDimension();
3348 builder, tileableOp, dimension, targetSize,
true);
3350 return emitSilenceableError() <<
"could not generate tile size computation";
3362 for (
auto &&[tileSize, tripCount] :
3363 llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3364 splitPoint = apply(s0 * s1, {tileSize, tripCount});
3365 chunkSizes.push_back(splitPoint);
3369 return llvm::map_to_vector(values, [&](
Value value) ->
Operation * {
3375 getDefiningOps(spec->tileSizes));
3376 transformResults.
set(cast<OpResult>(getChunkSizes()),
3377 getDefiningOps(chunkSizes));
3385 return emitOpError() <<
"expects all results type to be the same";
3391 void transform::ContinuousTileSizesOp::getEffects(
3409 Type &tileSizesType,
3410 Type &chunkSizesType) {
3411 FunctionType funcType;
3416 if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3417 parser.
emitError(typeLoc) <<
"expects a trailing functional type with one "
3418 "argument and one result";
3420 targetType = funcType.getInput(0);
3421 tileSizesType = chunkSizesType = funcType.getResult(0);
3430 void transform::TileUsingForOp::build(
3435 return build(builder, result, loopTypes,
3439 interchange, scalableSizes);
3442 void transform::TileUsingForOp::build(
3446 build(builder, result, target,
3448 interchange, scalableSizes);
3451 void transform::TileUsingForOp::build(
3458 build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3462 void transform::TileUsingForOp::build(
3474 unsigned numExpectedLoops =
3475 staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3477 resultTypes.reserve(numExpectedLoops);
3478 assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3479 "expected one loop type or as many as loops");
3480 if (loopTypes.size() == 1)
3481 resultTypes.append(numExpectedLoops, loopTypes[0]);
3483 llvm::append_range(resultTypes, loopTypes);
3485 if (scalableSizes.has_value())
3486 expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3487 build(builder, result, target.
getType(),
3491 staticTileSizesAttr,
3493 expandedScalableSizes);
3498 return emitOpError(
"expected same number of sizes (")
3500 << getScalableSizes().size() <<
")";
3502 unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3503 if (getLoops().size() != numExpectedLoops)
3504 return emitOpError(
"expected number of loops to tile (")
3505 << numExpectedLoops <<
") to match number of `loops` results ("
3506 << getLoops().size() <<
")";
3517 llvm::to_vector(state.getPayloadOps(getTarget()));
3523 if (isa<ParamType>(transformValue.getType())) {
3524 dynamicSizeProducers.push_back({});
3526 paramSizes.push_back(
3527 llvm::to_vector(llvm::map_range(params, [](
Attribute attr) {
3528 return cast<IntegerAttr>(attr).getValue().getSExtValue();
3531 if (paramSizes.back().size() != targets.size()) {
3533 emitSilenceableError()
3534 <<
"expected as many parameter values ("
3535 << dynamicSizeProducers.back().size() <<
") as target ops ("
3536 << targets.size() <<
")";
3537 diag.attachNote(transformValue.getLoc()) <<
"for this parameter";
3543 paramSizes.push_back({});
3544 dynamicSizeProducers.push_back(
3545 llvm::to_vector(state.getPayloadOps(transformValue)));
3547 if (dynamicSizeProducers.back().size() != targets.size()) {
3549 emitSilenceableError()
3550 <<
"expected as many dynamic size-producing operations ("
3551 << dynamicSizeProducers.back().size() <<
") as target ops ("
3552 << targets.size() <<
")";
3553 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3557 for (
Operation *op : dynamicSizeProducers.back()) {
3564 emitSilenceableError() <<
"expected sizes to be produced by ops "
3565 "with a single index-type result";
3566 diag.attachNote(op->
getLoc()) <<
"size producer op";
3567 diag.attachNote(transformValue.getLoc()) <<
"for this handle";
3574 loops.resize(getLoops().size());
3575 auto scalableSizes = getScalableSizes();
3577 auto tilingInterface = dyn_cast<TilingInterface>(op);
3578 if (!tilingInterface) {
3580 emitSilenceableError()
3581 <<
"only ops implementing TilingInterface are supported";
3582 diag.attachNote(op->
getLoc()) <<
"target op";
3585 if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3587 emitSilenceableError()
3588 <<
"too many tiles provided, expected at most "
3589 << tilingInterface.getLoopIteratorTypes().size() <<
" found "
3590 << tileSizes.size();
3591 diag.attachNote(op->
getLoc()) <<
"target op";
3596 if (tileSizes.empty()) {
3605 sizes.reserve(tileSizes.size());
3606 unsigned dynamicIdx = 0;
3609 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3610 if (scalableSizes[ofrIdx]) {
3612 b, getLoc(), cast<IntegerAttr>(attr).getInt());
3614 vector::VectorScaleOp::create(b, getLoc(), b.
getIndexType());
3616 arith::MulIOp::create(b, getLoc(), val, vscale).getResult());
3618 sizes.push_back(attr);
3625 assert((dynamicSizes.empty() ^ params.empty()) &&
3626 "expected either dynamic sizes or parameters");
3627 if (!params.empty()) {
3630 sizes.push_back(dynamicSizes[index]->getResult(0));
3638 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3639 tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3640 if (
failed(maybeTilingResult))
3643 rewriter.
replaceOp(op, maybeTilingResult->replacements);
3645 tiled.append(maybeTilingResult->tiledOps);
3647 loops[en2.index()].push_back(en2.value());
3650 transformResults.
set(cast<OpResult>(getTiledLinalgOp()), tiled);
3652 transformResults.
set(cast<OpResult>(getLoops()[en.index()]), en.value());
3661 results.reserve(tileSizes.size());
3662 unsigned dynamicPos = 0;
3664 for (int64_t size : tileSizes) {
3665 if (size == ShapedType::kDynamic) {
3666 results.push_back(dynamic[dynamicPos++]);
3674 void transform::TileUsingForOp::getEffects(
3686 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3690 ArrayAttr mapping) {
3691 return build(builder, result,
3699 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3703 ArrayAttr mapping) {
3713 build(builder, result,
3714 TypeRange{operationType, operationType},
3721 staticTileSizesAttr,
3725 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3729 ArrayAttr mapping) {
3730 return build(builder, result, target,
3735 void transform::TileUsingForallOp::build(
OpBuilder &builder,
3739 ArrayAttr mapping) {
3750 build(builder, result,
3751 TypeRange{operationType, operationType},
3757 staticNumThreadsAttr,
3770 AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3772 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3774 rewriter, loc, normalizedUbExpr, {lb, ub, step});
3775 normalizedUbs.push_back(normalizedUb);
3777 return normalizedUbs;
3793 for (
auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3796 denormalizedIvs.push_back(
3799 return denormalizedIvs;
3810 scf::ForallOp loop) {
3827 auto normalizedForallOp = scf::ForallOp::create(
3828 rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps,
3829 loop.getOutputs(), loop.getMapping(),
3832 auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3834 Block *normalizedLoopBlock = normalizedForallOp.getBody();
3839 argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3840 normalizedForallOp.getRegionIterArgs().end());
3841 Block *origLoopBlock = loop.getBody();
3842 rewriter.
mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3844 rewriter.
replaceOp(loop, normalizedForallOp);
3845 return normalizedForallOp;
3850 TransformOpInterface transformOp,
Operation *target,
3855 auto tileableOp = dyn_cast<TilingInterface>(target);
3858 transformOp.emitSilenceableError()
3859 <<
"only TilingInterface ops are supported";
3860 diag.attachNote(target->
getLoc()) <<
"target op";
3866 if (!mixedNumThreads.empty()) {
3867 options.setNumThreads(mixedNumThreads);
3869 options.setTileSizes(mixedTileSizes);
3872 options.setMapping(mapping.value().getValue());
3874 FailureOr<scf::SCFTilingResult> maybeTilingResult =
3877 if (
failed(maybeTilingResult))
3878 return transformOp.emitDefaultSilenceableFailure(tileableOp);
3880 rewriter.
replaceOp(tileableOp, maybeTilingResult->replacements);
3882 tilingResult = *maybeTilingResult;
3884 if (mixedNumThreads.empty()) {
3885 auto generatedForallOp = cast<scf::ForallOp>(tilingResult.
loops.front());
3888 scf::ForallOp normalizedForallOp =
3890 tilingResult.
loops.front() = normalizedForallOp;
3900 auto transformOp = cast<TransformOpInterface>(getOperation());
3909 getPackedNumThreads()
3911 state, transformOp, mixedNumThreads, getPackedNumThreads())
3913 state, transformOp, mixedNumThreads, getMixedNumThreads());
3917 status = getPackedTileSizes()
3919 state, transformOp, mixedTileSizes, getPackedTileSizes())
3921 state, transformOp, mixedTileSizes, getMixedTileSizes());
3925 for (
Operation *target : state.getPayloadOps(getTarget())) {
3928 rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3929 getMapping(), tilingResult);
3930 if (!
diag.succeeded())
3932 tileOps.push_back(tilingResult.
loops.front());
3933 tiledOps.append(tilingResult.
tiledOps);
3936 transformResults.
set(cast<OpResult>(getForallOp()), tileOps);
3937 transformResults.
set(cast<OpResult>(getTiledOp()), tiledOps);
3942 void transform::TileUsingForallOp::getEffects(
3955 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3964 int numThreadsSpec =
static_cast<int>(!getMixedNumThreads().empty()) +
3965 static_cast<int>(getPackedNumThreads() !=
Value());
3966 if (numThreadsSpec > 1)
3968 "num_threads and packed_num_threads are mutually exclusive");
3969 int tileSizesSpec =
static_cast<int>(!getMixedTileSizes().empty()) +
3970 static_cast<int>(getPackedTileSizes() !=
Value());
3971 if (tileSizesSpec > 1)
3973 "tile_sizes and packed_tile_sizes are mutually exclusive");
3974 if (numThreadsSpec == 0 && tileSizesSpec == 0)
3975 return emitOpError(
"either (packed_)num_threads or (packed_)tile_sizes "
3976 "must be specified");
3984 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3986 bool foldTypeExtensionsIntoContract,
bool vectorizePadding,
3987 bool vectorizeExtract,
bool flatten1DDepthwiseConv) {
3989 if (foldTypeExtensionsIntoContract) {
3991 VectorizeChildrenAndApplyPatternsOp::
3992 getFoldTypeExtensionsIntoContractAttrName(result.
name),
3995 if (vectorizePadding) {
3997 VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
4001 if (vectorizeExtract) {
4003 VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
4007 if (flatten1DDepthwiseConv) {
4009 VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
4020 explicit VectorizationPattern(
MLIRContext *context,
4021 bool vectorizeExtract =
false,
4022 bool flattenConv =
false)
4024 vectorizeNDExtract(vectorizeExtract),
4025 flatten1DDepthwiseConv(flattenConv) {}
4026 LogicalResult matchAndRewrite(
Operation *op,
4030 "Unsupported Op, cannot vectorize");
4031 FailureOr<VectorizationResult> vectorResults =
4033 {}, vectorizeNDExtract,
4034 flatten1DDepthwiseConv);
4035 if (
failed(vectorResults))
4037 rewriter.
replaceOp(op, vectorResults->replacements);
4044 bool vectorizeNDExtract =
false;
4048 bool flatten1DDepthwiseConv =
false;
4053 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
4058 auto diag = this->emitOpError(
"requires isolated-from-above targets");
4059 diag.attachNote(target->
getLoc()) <<
"non-isolated target";
4065 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
4066 getFlatten_1dDepthwiseConv());
4068 if (!getDisableTransferPermutationMapLoweringPatterns())
4071 if (!getDisableMultiReductionToContractPatterns())
4079 vector::TransferReadOp::getCanonicalizationPatterns(
patterns, ctx);
4080 vector::TransferWriteOp::getCanonicalizationPatterns(
patterns, ctx);
4085 if (getFoldTypeExtensionsIntoContract())
4088 if (getVectorizePadding()) {
4100 return emitDefaultDefiniteFailure(target);
4114 auto targets = state.getPayloadOps(getTarget());
4115 if (std::empty(targets))
4117 auto transformOp = cast<TransformOpInterface>(getOperation());
4120 state, transformOp, getMixedVectorSizes(), vectorSizes);
4128 <<
"Unsupported Op, cannot vectorize";
4130 FailureOr<VectorizationResult> vectorResults =
4132 getVectorizeNdExtract().value_or(
false),
4134 getAssumeDynamicDimsMatchVecSizes().value_or(
false),
4135 getCreateNamedContraction().value_or(
false));
4136 if (
failed(vectorResults)) {
4138 <<
"Attempted to vectorize, but failed";
4140 rewriter.
replaceOp(target, vectorResults->replacements);
4146 void transform::VectorizeOp::getEffects(
4155 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
4159 if (getStaticVectorSizes().size() != getScalableSizes().size())
4160 return emitOpError(
"expected same number of vector sizes (")
4161 << getStaticVectorSizes().size() <<
") and scalable sizes ("
4162 << getScalableSizes().size() <<
")";
4171 transform::HoistRedundantVectorTransfersOp::applyToOne(
4188 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
4207 auto maybeTransformed =
4210 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
4213 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4216 .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
4219 .Case([&](linalg::Conv2DNchwFchwOp op) {
4225 if (
failed(maybeTransformed))
4226 return emitDefaultSilenceableFailure(target);
4228 results.
push_back(maybeTransformed->first);
4230 results.
push_back(maybeTransformed->second);
4245 <<
"only elementwise flattening is supported";
4248 if (target.getNumLoops() <= 1) {
4255 std::iota(reassociation.begin(), reassociation.end(), 0);
4256 auto maybeFlattened =
4258 if (
failed(maybeFlattened))
4260 <<
"attempted to flatten, but failed";
4261 results.
push_back(maybeFlattened->collapsedOp);
4262 rewriter.
replaceOp(target, maybeFlattened->results);
4275 auto maybeTransformed =
4277 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4280 .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4286 if (
failed(maybeTransformed))
4287 return emitDefaultSilenceableFailure(target);
4302 bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4303 auto maybeTransformed =
4305 .Case([&](linalg::MatmulOp op) {
4308 .Case([&](linalg::BatchMatmulOp op) {
4311 .Default([&](
Operation *op) {
return failure(); });
4312 if (
failed(maybeTransformed))
4322 template <
typename OpTy>
4326 static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4327 tensor::ParallelInsertSliceOp>() &&
4330 if (
auto copySource =
4331 target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4339 if (isa<mlir::ParallelCombiningOpInterface>(target.getOperation()))
4342 Value extracted = tensor::ExtractSliceOp::create(
4343 rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(),
4344 target.getMixedSizes(), target.getMixedStrides());
4345 Value copied = linalg::CopyOp::create(rewriter, target.getLoc(),
4346 target.getSource(), extracted)
4351 target, copied, target.getDest(), target.getMixedOffsets(),
4352 target.getMixedSizes(), target.getMixedStrides());
4364 if (
auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4365 return doit(rewriter, target, results, state);
4366 if (
auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4367 return doit(rewriter, target, results, state);
4370 emitSilenceableError()
4371 <<
"only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4372 diag.attachNote(targetOp->
getLoc()) <<
"target op";
4385 if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4387 emitSilenceableError()
4388 <<
"only linalg.copy and tensor.pad target ops are supported";
4389 diag.attachNote(target->
getLoc()) <<
"target op";
4392 assert(target->
getNumResults() == 1 &&
"expected single result");
4394 if (!resultShapedType.hasStaticShape()) {
4396 emitSilenceableError()
4397 <<
"only statically sized ops of rank <= 3 are supported";
4398 diag.attachNote(target->
getLoc()) <<
"target op";
4403 int64_t desiredBitAlignment = getDesiredBitAlignment();
4404 int64_t eltBitwidth =
4405 resultShapedType.getElementType().getIntOrFloatBitWidth();
4406 if (desiredBitAlignment % eltBitwidth != 0) {
4407 desiredBitAlignment = eltBitwidth;
4412 getTotalNumThreads(),
4413 desiredBitAlignment,
4414 resultShapedType.getShape(),
4417 resultShapedType.getElementType().getIntOrFloatBitWidth());
4418 if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4420 emitSilenceableError()
4421 <<
"too few threads to map copy op to threads on the most minor "
4422 "dimension, given alignment and vector size constraints, try "
4423 "smaller tile size of mapping to more threads";
4424 diag.attachNote(target->
getLoc()) <<
"target op";
4440 if (!
diag.succeeded())
4444 for (
auto op : tilingResult.
tiledOps)
4458 FailureOr<Operation *> maybeTransformed = failure();
4460 .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4465 .Default([&](
Operation *op) {
return false; });
4468 return emitSilenceableError()
4469 <<
"this operation is not supported to convert to Winograd Conv2D";
4472 if (
failed(maybeTransformed)) {
4473 return emitSilenceableError() <<
"apply Winograd Conv2D failed";
4485 FailureOr<Operation *> maybeTransformed = failure();
4488 .Case([&](linalg::WinogradFilterTransformOp op) {
4492 .Case([&](linalg::WinogradInputTransformOp op) {
4496 .Case([&](linalg::WinogradOutputTransformOp op) {
4500 .Default([&](
Operation *op) {
return false; });
4504 emitSilenceableError()
4505 <<
"this operation is not supported to decompose into other operations";
4506 diag.attachNote(target->
getLoc()) <<
"target op";
4510 if (
failed(maybeTransformed)) {
4512 emitSilenceableError() <<
"decompose Winograd operations failed";
4513 diag.attachNote(target->
getLoc()) <<
"target op";
4521 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4523 #define GET_OP_CLASSES
4524 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
This class represents a saved insertion point.
bool isSet() const
Returns true if this insert point is set.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getOpResult(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
void setOperand(unsigned idx, Value value)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
user_range getUsers()
Returns a range of all users.
result_range getOpResults()
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
user_range getUsers() const
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e.g., converting it to a minor identity map) by inserting broadcasts and transposes.
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for the backends with native support.
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the contract.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult for a single OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true, will be replaced by dynamicValues.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
A listener that forwards all notifications to another listener.
ForwardingListener(OpBuilder::Listener *listener)
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
@ MaterializeInDestination
Transformation to drop unit-extent dimensions from linalg.generic operations.
Vectorization pattern for memref::CopyOp.
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise convolution ops.
Match and rewrite for the pattern:
Match and rewrite for the pattern:
@ BufferizationMaterializeInDestination
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
SCFTilingOptions & setLoopType(LoopType type)
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.