#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the sequential type are different than the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  return AffineMap::get(sequentialType.getRank(), 0, perm,
                        distributedType.getContext());
}
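// Illustrative sketch (not from the original source): for a sequential type
// vector<16x32xf32> distributed over 32 lanes along dim 1 as vector<16x1xf32>,
// only dim 1 differs, so the implicit map is (d0, d1) -> (d1).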
/// Given a sequential and distributed vector type, returns the distributed
/// dimension. This function expects that only a single dimension is
/// distributed.
static int getDistributedDim(VectorType sequentialType,
                             VectorType distributedType) {
  assert(sequentialType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  int64_t distributedDim = -1;
  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
      assert(distributedDim == -1 && "found multiple distributed dims");
      distributedDim = i;
    }
  }
  return distributedDim;
}
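// Illustrative sketch (not from the original source): with sequential type
// vector<16x32xf32> and distributed type vector<16x1xf32>, dim 1 is the only
// mismatching dimension, so getDistributedDim returns 1.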
/// Helper to build the loads / stores that move values through a scratch
/// buffer across the distributed / sequential boundary. The distribution map
/// is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }
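  // Illustrative sketch (not from the original source): for a distributed dim
  // of size 2, lane k accesses the buffer at offset 2 * k, i.e. the affine
  // expression s0 * 2 applied to the lane id.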
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return memref::StoreOp::create(b, loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferWriteOp::create(
        b, loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return memref::LoadOp::create(b, loc, buffer, zero);

    // Other cases must be vector atm.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferReadOp::create(
        b, loc, cast<VectorType>(type), buffer, indices,
        /*padding=*/std::nullopt,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }
  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: create the scf.if op guarded on lane 0.
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value isLane0 = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
                                  /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
    // Step 2: store the distributed warp op arguments into scratch buffers
    // and reload them in sequential form inside the scf.if.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }
    // Step 3: synchronize after the stores and before the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4: move the body of the warp op into the scf.if then-block.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
    // Step 5: store the yielded (sequential) values into scratch buffers
    // inside the scf.if and reload them in distributed form after it.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);
      // Load distributed value from buffer, after the ifOp.
      rewriter.setInsertionPointAfter(ifOp);
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }
    // Step 6: synchronize after the stores and before the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7: delete the terminator and add an empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    scf::YieldOp::create(rewriter, yieldLoc);

    // Step 8: replace the warp op results with the reloaded values.
    rewriter.replaceOp(warpOp, replacements);
    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};
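// Illustrative sketch (not from the original source) of the rewrite performed
// by WarpOpToScfIfPattern, assuming a warp size of 32:
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<1xf32>) -> (vector<1xf32>) {
//   ^bb0(%arg : vector<32xf32>):
//     ...
//     gpu.yield %r : vector<32xf32>
//   }
//
// becomes, modulo the allocation / synchronization hooks:
//
//   store %v into buffer0 at the lane offset; barrier
//   scf.if %laneid == 0 {
//     %arg = load buffer0 as vector<32xf32>
//     ...
//     store %r into buffer1
//   }
//   barrier; %0 = load buffer1 at the lane offset as vector<1xf32>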
/// Return the distributed vector type based on the original type and the
/// distribution map. Returns a null VectorType if the shape cannot be
/// distributed evenly over `warpSize` lanes.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0)
        return VectorType();
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1)
    return VectorType();
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}
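// Illustrative sketch (not from the original source): originalType
// vector<64xf32> with map (d0) -> (d0) and warpSize 32 yields vector<2xf32>.
// With originalType vector<16xf32>, dim 0 is set to 1 and the remaining
// factor of 2 must divide another mapped dim, else distribution fails.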
/// Collect the values defined inside the warp op but outside `innerRegion`
/// that are used inside `innerRegion`, together with their original types and
/// their distributed types.
static std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
                  SmallVector<Type>>
getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
                             const DistributionMapFn &distributionMapFn) {
  llvm::SmallSetVector<Value, 32> escapingValues;
  SmallVector<Type> escapingValueTypes;
  SmallVector<Type> escapingValueDistTypes;
  if (innerRegion.empty())
    return {std::move(escapingValues), std::move(escapingValueTypes),
            std::move(escapingValueDistTypes)};
  mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
    Operation *parent = operand->get().getParentRegion()->getParentOp();
    if (warpOp->isAncestor(parent)) {
      if (!escapingValues.insert(operand->get()))
        return;
      Type distType = operand->get().getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(operand->get());
        distType = getDistributedType(
            vecType, map, map.isEmpty() ? 1 : warpOp.getWarpSize());
      }
      escapingValueTypes.push_back(operand->get().getType());
      escapingValueDistTypes.push_back(distType);
    }
  });
  return {std::move(escapingValues), std::move(escapingValueTypes),
          std::move(escapingValueDistTypes)};
}
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}
  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, there is nothing to distribute.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5. Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // Distribution of masked writes currently requires a minor identity
      // permutation map, so that mask and written vector distribute the same
      // way.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
                         rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
                         delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, elide the delinearization op and use
      // the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);
    return success();
  }
  /// Fallback path when the write cannot be distributed: sink a small write
  /// into its own result-less warp op, keeping the written vector yielded
  /// from the main warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = WarpExecuteOnLane0Op::create(
        rewriter, loc, TypeRange(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
    return success();
  }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }
private:
  /// Clone `writeOp`, assumed to be nested under `warpOp`, into a new warp
  /// execute op with the proper return type.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp,
          ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }
  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op) &&
             op->getNumResults() == 1;
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};
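// Illustrative sketch (not from the original source), assuming warp size 32:
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %1 = arith.addf %a, %b : vector<32xf32>
//     gpu.yield %1 : vector<32xf32>
//   }
//
// becomes
//
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
//       -> (vector<1xf32>, vector<1xf32>) {
//     gpu.yield %a, %b : vector<32xf32>, vector<32xf32>
//   }
//   %0 = arith.addf %r#0, %r#1 : vector<1xf32>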
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
struct WarpOpStep final : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
    if (!yieldOperand)
      return failure();
    unsigned operandIdx = yieldOperand->getOperandNumber();
    auto stepOp = cast<vector::StepOp>(yieldOperand->get().getDefiningOp());
    VectorType resTy = stepOp.getResult().getType();
    if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv("Expected result size ({0}) to equal warp size ({1})",
                        resTy.getNumElements(), warpOp.getWarpSize()));
    VectorType newVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    rewriter.setInsertionPointAfter(warpOp);
    Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
                                                  newVecTy, warpOp.getLaneid());
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
    return success();
  }
};
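// Illustrative sketch (not from the original source): with warp size 32, a
// yielded vector.step : vector<32xindex> distributes to vector<1xindex> per
// lane, and lane k's slice is exactly [k], i.e. a broadcast of the lane id.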
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // Distribution of masked reads currently requires a minor identity
      // permutation map, so that mask and read vector distribute the same
      // way.
      if (!read.getPermutationMap().isMinorIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);
    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = vector::TransferReadOp::create(
        rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
        newIndices, read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());
    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};
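// Illustrative sketch (not from the original source), assuming warp size 32:
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %1 = vector.transfer_read %src[%c0], %pad
//         : memref<32xf32>, vector<32xf32>
//     gpu.yield %1 : vector<32xf32>
//   }
//
// becomes a per-lane read outside the warp op:
//
//   %0 = vector.transfer_read %src[%laneid], %pad
//       : memref<32xf32>, vector<1xf32>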
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    gpu::YieldOp yield = warpOp.getTerminator();

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicate by recording, for each result, the first position
    // at which its yielded value occurs, and skip dead results as well as
    // results whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (!it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // The broadcast must span the same value uniformly across all lanes, so
    // the source must come from outside the warp op region.
    if (!warpOp.isDefinedOutsideOfRegion(broadcastSrc))
      return failure();
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = vector::BroadcastOp::create(
        rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};
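// Illustrative sketch (not from the original source): a uniform scalar %s
// broadcast to vector<32xf32> inside the warp op becomes, per lane,
// %0 = vector.broadcast %s : f32 to vector<1xf32> outside of it.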
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the
    // original type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = vector::ShapeCastOp::create(
        rewriter, oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
    if (!yieldOperand)
      return failure();

    Operation *mask = yieldOperand->get().getDefiningOp();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (isa<vector::CreateMaskOp>(mask) &&
        !llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask->getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);
    SmallVector<Value> materializedOperands;
    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
      materializedOperands.append(mask->getOperands().begin(),
                                  mask->getOperands().end());
    } else {
      // vector.constant_mask case: materialize the static mask dimension
      // sizes as constants.
      auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
      auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
      for (auto dimSize : dimSizes)
        materializedOperands.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, dimSize));
    }

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Compute `mask_dim_size[i] - lane_id[i] * dist_shape[i]`, i.e. the
      // number of in-bounds mask elements remaining for this lane along
      // dim i, which the distributed create_mask clamps appropriately.
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], materializedOperands[i]});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    return success();
  }
};
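// Illustrative sketch (not from the original source): for a sequential
// vector.create_mask %c20 : vector<32xi1> distributed over 32 lanes as
// vector<1xi1>, lane k computes 20 - k * 1, so lanes 0..19 get a positive
// per-lane mask size (active) and lanes 20..31 a non-positive one (inactive).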
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distribution of scalar and 1d vectors is not supported.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");

    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");

    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // The distributed dimension must be in the last k dims of the dest
    // vector, where k is the rank of the source vector.
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp,
          "distributed dimension must be in the last k dims of dest vector");
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");
    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    VectorType newDestTy = distributedType;
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Create a new insert strided slice op that inserts the distributed
    // source into the distributed dest.
    Value newInsert = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), distributedDest.getType(),
        distributedSource, distributedDest, insertOp.getOffsets(),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newInsert);
    return success();
  }
};
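// Illustrative sketch (not from the original source), assuming warp size 32
// and distribution along the trailing dim: inserting a vector<4x32xf32> slice
// into a vector<8x32xf32> dest becomes, per lane, inserting vector<4x1xf32>
// into vector<8x1xf32> at the same row offsets.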
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp =
        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distribution of scalar and 1d vectors is not supported.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          extractOp, "result vector type must be 2D or higher");

    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");

    int64_t numOfExtractedDims =
        static_cast<int64_t>(extractOp.getSizes().size());
    // If the distributed dim is one of the extracted dims, its offset must be
    // 0 and the extracted size must cover the whole dimension, i.e. the
    // distributed dimension must be fully extracted.
    if (distributedDim < numOfExtractedDims) {
      int64_t distributedDimOffset =
          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
              .getInt();
      int64_t distributedDimSize =
          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
              .getInt();
      if (distributedDimOffset != 0 ||
          distributedDimSize != yieldedType.getDimSize(distributedDim))
        return rewriter.notifyMatchFailure(
            extractOp, "distributed dimension must be fully extracted");
    }
    SmallVector<int64_t> newDistributedShape(
        extractOp.getSourceVectorType().getShape());
    newDistributedShape[distributedDim] =
        distributedType.getDimSize(distributedDim);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    // If the distributed dim is one of the extracted dims, the extracted size
    // along it becomes the per-lane size.
    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));

    // Extract from the distributed vector.
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, distributedVec,
        extractOp.getOffsets(),
        ArrayAttr::get(rewriter.getContext(), distributedSizes),
        extractOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // 0-D and 1-D sources are handled by WarpOpExtractScalar.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() ==
        operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the
      // extract out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getSource()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      Value newExtract = vector::ExtractOp::create(
          rewriter, loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // The source of the extract is distributed the same way along its
    // trailing dims as the yielded vector.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from the distributed vector.
    Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
                                                 extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only 0-D or 1-D vectors supported.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }

    // Only f32 and i32 are currently supported by the shuffle hook.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield the source vector and the dynamic position (if any) from the
    // warp op.
    SmallVector<Value> additionalResults{extractOp.getSource()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0-D and vector<1x..> case: the value is already on every lane.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));

    // 1-D case: compute the owning lane and the position within its slice.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // Lane that owns the element at `pos`.
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Position within the owning lane's slice: pos % elementsPerLane.
    Value newPos =
        elementsPerLane == 1
            ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);

    // Shuffle the extracted value from the owning lane to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
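// Illustrative sketch (not from the original source), assuming warp size 32
// and elementsPerLane == 2 for a vector<64xf32> source: extracting element
// %pos becomes an extract of element %pos mod 2 from the owning lane's
// vector<2xf32> slice, followed by a warp shuffle from that lane to all lanes.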
struct WarpOpInsertScalar : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Only 0-D or 1-D vectors supported.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D source supported for now");
    }

    // Yield destination vector, source scalar and dynamic position (if any)
    // from the warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getValueToStore()};
    SmallVector<Type> additionalResultTypes{
        distrType, insertOp.getValueToStore().getType()};
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-d vectors.
    if (vecType == distrType) {
      // Broadcast: simply move the vector.insert out of the warp op.
      Value newInsert;
      SmallVector<OpFoldResult> indices;
      if (pos) {
        indices.push_back(pos);
      }
      newInsert = vector::InsertOp::create(rewriter, loc, newSource,
                                           distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution: only one lane should insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // Lane that owns the element at `pos`.
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Insert position within the owning lane: pos % elementsPerLane.
    Value newPos = affine::makeComposedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
    Value isInsertingLane =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                              newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        scf::IfOp::create(
            rewriter, loc, isInsertingLane,
            /*thenBuilder=*/
            [&](OpBuilder &builder, Location loc) {
              Value newInsert = vector::InsertOp::create(
                  builder, loc, newSource, distributedVec, newPos);
              scf::YieldOp::create(builder, loc, newInsert);
            },
            /*elseBuilder=*/
            [&](OpBuilder &builder, Location loc) {
              scf::YieldOp::create(builder, loc, distributedVec);
            })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newResult);
    return success();
  }
};
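// Illustrative sketch (not from the original source), assuming warp size 32:
// inserting a scalar at position %pos of a vector<64xf32> dest (distributed
// as vector<2xf32>) wraps the per-lane vector.insert in an scf.if so that
// only the lane owning %pos modifies its slice; all other lanes forward
// their slice unchanged.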
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // 0-D and 1-D destinations are handled by WarpOpInsertScalar.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() ==
        operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the
      // insert out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                                 distributedDest,
                                                 insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension of the dest. There should be exactly
    // one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type. If the distributed dim is
    // also a dim of the source vector, the source is distributed the same
    // way; otherwise each lane keeps the full source and only the owning
    // lane performs the insert.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from the warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts its small piece.
      newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                           distributedDest,
                                           insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // Lane that owns the insert position: pos / elementsPerLane.
      Value insertingLane = arith::ConstantIndexOp::create(
          rewriter, loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane =
          arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                newWarpOp.getLaneid(), insertingLane);
      // Insert position within the owning lane: pos % elementsPerLane.
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
                                                   distributedDest, newPos);
        scf::YieldOp::create(builder, loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        scf::YieldOp::create(builder, loc, distributedDest);
      };
      newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
                                    /*thenBuilder=*/insertingBuilder,
                                    /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newResult);
    return success();
  }
};
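// Illustrative sketch (not from the original source), assuming warp size 32
// and a dest distributed along its trailing dim: inserting vector<32xf32> at
// row %r of a vector<4x32xf32> dest (distributed as vector<4x1xf32>) becomes
// each lane inserting its vector<1xf32> slice at row %r.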
struct WarpOpScfIfOp : public WarpDistributionPattern {
  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // The scf.if op must be the last op (besides the terminator) in the warp
    // op body.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
    if (!ifOp)
      return failure();

    // Classify the warp op yield operands: values yielded from the scf.if
    // results are distributed through the new scf.if; all others are passed
    // through unchanged.
    SmallVector<Value> nonIfYieldValues;
    SmallVector<unsigned> nonIfYieldIndices;
    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
        nonIfYieldValues.push_back(yieldOperand.get());
        nonIfYieldIndices.push_back(yieldOperandIdx);
        continue;
      }
      OpResult ifResult = cast<OpResult>(yieldOperand.get());
      const unsigned ifResultIdx = ifResult.getResultNumber();
      ifResultMapping[yieldOperandIdx] = ifResultIdx;
      // If the yielded value is a vector, record the distributed type chosen
      // by the enclosing warp op result.
      if (!isa<VectorType>(ifResult.getType()))
        continue;
      VectorType distType =
          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
      ifResultDistTypes[ifResultIdx] = distType;
    }
    // Collect the values escaping into each branch of the scf.if, together
    // with their sequential and distributed types.
    auto [escapingValuesThen, escapingValueInputTypesThen,
          escapingValueDistTypesThen] =
        getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
                                     distributionMapFn);
    auto [escapingValuesElse, escapingValueInputTypesElse,
          escapingValueDistTypesElse] =
        getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
                                     distributionMapFn);
    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
        llvm::is_contained(escapingValueDistTypesElse, Type{}))
      return failure();

    // The new warp op yields the scf.if condition, the escaping values of
    // both branches, and finally the non-scf.if yielded values.
    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
    newWarpOpYieldValues.append(escapingValuesThen.begin(),
                                escapingValuesThen.end());
    newWarpOpYieldValues.append(escapingValuesElse.begin(),
                                escapingValuesElse.end());
    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
                              escapingValueDistTypesThen.end());
    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                              escapingValueDistTypesElse.end());
    for (auto [idx, val] :
         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
      newWarpOpYieldValues.push_back(val);
      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
    }

    // Create the new warp op with the appended yields.
    SmallVector<size_t> newIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
    // Compute the distributed result types of the new scf.if.
    SmallVector<Type> newIfOpDistResTypes;
    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
      Type distType = cast<Value>(res).getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(cast<Value>(res));
        distType = ifResultDistTypes.count(i)
                       ? ifResultDistTypes[i]
                       : getDistributedType(
                             vecType, map,
                             map.isEmpty() ? 1 : newWarpOp.getWarpSize());
      }
      newIfOpDistResTypes.push_back(distType);
    }
    // Create a new scf.if outside the new warp op region.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newIfOp = scf::IfOp::create(
        rewriter, ifOp.getLoc(), newIfOpDistResTypes,
        newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
        static_cast<bool>(ifOp.elseBlock()));
    // Enclose each branch region of the original scf.if within a new inner
    // warp op, rewiring escaping values through the inner warp op block
    // arguments.
    auto encloseRegionInWarpOp =
        [&](Block *oldIfBranch, Block *newIfBranch,
            llvm::SmallSetVector<Value, 32> &escapingValues,
            SmallVector<Type> &escapingValueInputTypes,
            size_t warpResRangeStart) {
          OpBuilder::InsertionGuard g(rewriter);
          rewriter.setInsertionPointToStart(newIfBranch);
          llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
          SmallVector<Value> innerWarpInputVals;
          SmallVector<Type> innerWarpInputTypes;
          for (size_t i = 0; i < escapingValues.size();
               ++i, ++warpResRangeStart) {
            innerWarpInputVals.push_back(
                newWarpOp.getResult(newIndices[warpResRangeStart]));
            escapeValToBlockArgIndex[escapingValues[i]] =
                innerWarpInputTypes.size();
            innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
          }
          auto innerWarp = WarpExecuteOnLane0Op::create(
              rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
              newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
              innerWarpInputVals, innerWarpInputTypes);
          // Move the original branch body into the inner warp op and add
          // block arguments for the escaping values.
          innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
          innerWarp.getWarpRegion().addArguments(
              innerWarpInputTypes,
              SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
          // Replace the scf.yield terminator with a gpu.yield.
          SmallVector<Value> yieldOperands;
          for (Value operand :
               innerWarp.getBody()->getTerminator()->getOperands())
            yieldOperands.push_back(operand);
          rewriter.eraseOp(innerWarp.getBody()->getTerminator());
          rewriter.setInsertionPointToEnd(innerWarp.getBody());
          gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
          rewriter.setInsertionPointAfter(innerWarp);
          scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
          // Redirect uses of escaping values to the inner warp op block
          // arguments.
          innerWarp.walk([&](Operation *op) {
            for (OpOperand &operand : op->getOpOperands()) {
              auto it = escapeValToBlockArgIndex.find(operand.get());
              if (it == escapeValToBlockArgIndex.end())
                continue;
              operand.set(innerWarp.getBodyRegion().getArgument(it->second));
            }
          });
          // Hoist uniform scalar code out of the inner warp op.
          mlir::vector::moveScalarUniformCode(innerWarp);
        };
    encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
                          &newIfOp.getThenRegion().front(), escapingValuesThen,
                          escapingValueInputTypesThen, 1);
    if (!ifOp.getElseRegion().empty())
      encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
                            &newIfOp.getElseRegion().front(),
                            escapingValuesElse, escapingValueInputTypesElse,
                            1 + escapingValuesThen.size());
    // Update the users of the original warp op results that came from the
    // scf.if to use the new scf.if results.
    for (auto [origIdx, newIdx] : ifResultMapping)
      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                    newIfOp.getResult(newIdx), newIfOp);
    rewriter.eraseOp(ifOp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};
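// Illustrative sketch (not from the original source): an scf.if yielding a
// vector<32xf32> out of the warp op is hoisted outside of it, with each
// branch body wrapped in its own inner gpu.warp_execute_on_lane_0 that
// receives the escaping values as extra arguments; the hoisted scf.if then
// yields the distributed vector<1xf32> per lane.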
struct WarpOpScfForOp : public WarpDistributionPattern {
  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // The scf.for op must be the last op (besides the terminator) in the
    // warp op body.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();

    // Collect the values escaping into the scf.for body, together with their
    // sequential and distributed types.
    auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
        getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
                                     distributionMapFn);
    if (llvm::is_contained(escapingValueDistTypes, Type{}))
      return failure();

    // Classify the warp op yield operands: values yielded from the scf.for
    // results are distributed through the new scf.for; all others are passed
    // through unchanged.
    SmallVector<Value> nonForYieldedValues;
    SmallVector<unsigned> nonForResultIndices;
    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      // Yielded value is not a result of the forOp.
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
        nonForYieldedValues.push_back(yieldOperand.get());
        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
        continue;
      }
      // Map the forOp result to the corresponding warpOp result.
      OpResult forResult = cast<OpResult>(yieldOperand.get());
      unsigned forResultNumber = forResult.getResultNumber();
      forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
      // If the yielded value is a vector, record the distributed type chosen
      // by the enclosing warp op result.
      if (!isa<VectorType>(forResult.getType()))
        continue;
      VectorType distType = cast<VectorType>(
          warpOp.getResult(yieldOperand.getOperandNumber()).getType());
      forResultDistTypes[forResultNumber] = distType;
    }
    // The new warp op yields the loop bounds and step, the distributed init
    // args, the escaping values, and finally the non-scf.for yielded values.
    SmallVector<Value> newWarpOpYieldValues;
    SmallVector<Type> newWarpOpDistTypes;
    newWarpOpYieldValues.insert(
        newWarpOpYieldValues.end(),
        {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
                              {forOp.getLowerBound().getType(),
                               forOp.getUpperBound().getType(),
                               forOp.getStep().getType()});
    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
      newWarpOpYieldValues.push_back(initArg);
      // Compute the distributed type for this init arg.
      Type distType = initArg.getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        // If the init arg feeds a forOp result that is yielded by the warp
        // op, reuse the already chosen distributed type.
        AffineMap map = distributionMapFn(initArg);
        distType = forResultDistTypes.count(i)
                       ? forResultDistTypes[i]
                       : getDistributedType(
                             vecType, map,
                             map.isEmpty() ? 1 : warpOp.getWarpSize());
      }
      newWarpOpDistTypes.push_back(distType);
    }
    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
                                escapingValues.begin(), escapingValues.end());
    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
                              escapingValueDistTypes.begin(),
                              escapingValueDistTypes.end());
    for (auto [i, v] :
         llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
      newWarpOpYieldValues.push_back(v);
      newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
    }

    // Create the new warp op with the appended yields.
    SmallVector<size_t> newIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
    // The distributed loop bounds and step come first, then the init args,
    // then the escaping values.
    const unsigned initArgsStartIdx = 3;
    const unsigned escapingValuesStartIdx =
        initArgsStartIdx + forOp.getInitArgs().size();

    SmallVector<Value> newForOpOperands;
    for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
      newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));

    // Create a new scf.for outside the new warp op region.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newForOp = scf::ForOp::create(
        rewriter, forOp.getLoc(), newWarpOp.getResult(newIndices[0]),
        newWarpOp.getResult(newIndices[1]), newWarpOp.getResult(newIndices[2]),
        newForOpOperands, /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
    // The inner warp op receives the distributed iter args plus the escaping
    // values; its region expects their sequential types.
    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
                                      newForOp.getRegionIterArgs().end());
    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
                                         forOp.getResultTypes().end());
    // Escaping values are forwarded through extra inner warp op arguments;
    // remember which block argument each escaping value maps to.
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (size_t i = escapingValuesStartIdx;
         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
      innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
          innerWarpInputType.size();
      innerWarpInputType.push_back(
          escapingValueInputTypes[i - escapingValuesStartIdx]);
    }
    // Create the inner warp op inside the new scf.for body.
    rewriter.setInsertionPointToStart(newForOp.getBody());
    auto innerWarp = WarpExecuteOnLane0Op::create(
        rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
        newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
        innerWarpInputType);
    // Move the old scf.for body into the inner warp op.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments())
      argMapping.push_back(args);
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

    // Insert a gpu.yield at the end of the inner warp op region.
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
    // Update the users of the original warp op results that came from the
    // scf.for to use the new scf.for results.
    rewriter.setInsertionPointAfter(newForOp);
    for (auto [origIdx, newIdx] : forResultMapping)
      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                    newForOp.getResult(newIdx), newForOp);
    // Redirect uses of escaping values to the inner warp op block arguments.
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist uniform scalar code out of the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    rewriter.eraseOp(forOp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};
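// Illustrative sketch (not from the original source), assuming warp size 32:
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4xf32>) {
//     %init = ...
//     %r = scf.for %i = %c0 to %c128 step %c1 iter_args(%acc = %init)
//         -> (vector<128xf32>) {
//       %next = ... %acc ...
//       scf.yield %next : vector<128xf32>
//     }
//     gpu.yield %r : vector<128xf32>
//   }
//
// becomes an scf.for over distributed iter_args of type vector<4xf32> whose
// body re-enters a gpu.warp_execute_on_lane_0 wrapping the original
// (sequential) loop body.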
struct WarpOpReduction : public WarpDistributionPattern {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    unsigned operandIndex = yieldOperand->getOperandNumber();
    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp-size multiples supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Yield the vector to reduce (and the accumulator, if present) from the
    // warp op so that each lane receives its slice.
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain the distributed vector for this lane and reduce it across the
    // warp using the injected reduction function.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

  DistributedReductionFn distributedReductionFn;
};
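// Illustrative sketch (not from the original source), assuming warp size 32:
// a vector.reduction <add> over vector<64xf32> is rewritten so each lane
// yields its vector<2xf32> slice, then the injected DistributedReductionFn
// (e.g. a butterfly shuffle sequence) combines the 32 partial values into the
// full result, with the accumulator folded in afterwards.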
void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
               WarpOpCreateMask<vector::CreateMaskOp>,
               WarpOpCreateMask<vector::ConstantMaskOp>,
               WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
      patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(),
                                    warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfIfOp, WarpOpScfForOp>(patterns.getContext(),
                                              distributionMapFn, benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}
/// Helper to know if an op can be hoisted out of the region: all its operands
/// must be defined outside, and it must be side-effect free and region-less.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
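// Illustrative sketch (not from the original source): scalar ops such as
// arith.addi %c1, %c2 : index inside the warp op body are uniform across
// lanes and get hoisted above the gpu.warp_execute_on_lane_0, while any op
// producing a vector result stays inside.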