22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include "llvm/Support/FormatVariadic.h"
// NOTE(review): fragmentary extraction — the signatures and several interior
// lines of these two helpers are missing (embedded original line numbers jump).
// First fragment: appears to walk matching ranks of a sequential vs. distributed
// VectorType and build something from the dims that differ — TODO confirm upstream.
43 VectorType distributedType) {
49 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
54 distributedType.getContext());
// Second fragment: finds the single dimension whose size differs between the
// sequential and distributed types (the "distributed" dim) and asserts there
// is at most one such dim; returns its index (presumably -1 if none).
62 VectorType distributedType) {
63 assert(sequentialType.getRank() == distributedType.getRank() &&
64 "sequential and distributed vector types must have the same rank");
66 for (
int64_t i = 0; i < sequentialType.getRank(); ++i) {
67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
70 assert(distributedDim == -1 &&
"found multiple distributed dims");
74 return distributedDim;
// Helper that pairs a sequential value with its distributed counterpart and
// materializes the loads/stores that move them through a shared buffer.
// `laneId` indexes the distributed dims; `zero` is a reusable index constant.
// NOTE(review): interior lines are missing from this extraction; comments are
// based only on the visible statements.
84struct DistributedLoadStoreHelper {
85 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
86 Value laneId, Value zero)
87 : sequentialVal(sequentialVal), distributedVal(distributedVal),
88 laneId(laneId), zero(zero) {
// Both types may be non-vector (scalar case): dyn_cast yields null then.
89 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
90 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
91 if (sequentialVectorType && distributedVectorType)
// Offset of this lane's slice along `index`: laneId * distributedSize,
// folded through an affine.apply.
96 Value buildDistributedOffset(RewriterBase &
b, Location loc, int64_t index) {
97 int64_t distributedSize = distributedVectorType.getDimSize(index);
99 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
100 ArrayRef<Value>{laneId});
// Store `val` (must be one of the two registered values) into `buffer`:
// scalars via memref.store, vectors via an in-bounds vector.transfer_write
// whose indices are lane-dependent only for the distributed value.
110 Operation *buildStore(RewriterBase &
b, Location loc, Value val,
112 assert((val == distributedVal || val == sequentialVal) &&
113 "Must store either the preregistered distributed or the "
114 "preregistered sequential value.");
116 if (!isa<VectorType>(val.
getType()))
117 return memref::StoreOp::create(
b, loc, val, buffer, zero);
121 int64_t rank = sequentialVectorType.getRank();
122 SmallVector<Value>
indices(rank, zero);
123 if (val == distributedVal) {
124 for (
auto dimExpr : distributionMap.getResults()) {
125 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126 indices[index] = buildDistributedOffset(
b, loc, index);
129 SmallVector<bool> inBounds(
indices.size(),
true);
130 return vector::TransferWriteOp::create(
132 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
// Mirror of buildStore: memref.load for scalars, in-bounds
// vector.transfer_read for vectors.
// NOTE(review): the assert text says "store" but this guards a *load* type —
// looks like a copy-paste from buildStore; present upstream as well.
155 Value buildLoad(RewriterBase &
b, Location loc, Type type, Value buffer) {
158 if (!isa<VectorType>(type))
159 return memref::LoadOp::create(
b, loc, buffer, zero);
164 assert((type == distributedVectorType || type == sequentialVectorType) &&
165 "Must store either the preregistered distributed or the "
166 "preregistered sequential type.");
167 SmallVector<Value>
indices(sequentialVectorType.getRank(), zero);
168 if (type == distributedVectorType) {
169 for (
auto dimExpr : distributionMap.getResults()) {
170 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171 indices[index] = buildDistributedOffset(
b, loc, index);
174 SmallVector<bool> inBounds(
indices.size(),
true);
175 return vector::TransferReadOp::create(
176 b, loc, cast<VectorType>(type), buffer,
indices,
178 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
181 Value sequentialVal, distributedVal, laneId, zero;
182 VectorType sequentialVectorType, distributedVectorType;
// Map from distributed result dims to sequential dims — set elsewhere
// (setter not visible in this chunk).
183 AffineMap distributionMap;
// (orphan fragment of a preceding helper — its definition is not visible here)
196 return rewriter.
create(res);
// Pattern lowering WarpExecuteOnLane0Op to an scf.if that only lane 0 enters.
// Region args and yielded results are communicated through buffers obtained
// from options.warpAllocationFn, with options.warpSynchronizationFn inserted
// after each store phase. NOTE(review): interior lines missing; hedged.
230 WarpOpToScfIfPattern(MLIRContext *context,
231 const WarpExecuteOnLane0LoweringOptions &options,
232 PatternBenefit benefit = 1)
233 : WarpDistributionPattern(context, benefit), options(options) {}
235 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236 PatternRewriter &rewriter)
const override {
237 assert(warpOp.getBodyRegion().hasOneBlock() &&
238 "expected WarpOp with single block");
239 Block *warpOpBody = &warpOp.getBodyRegion().front();
240 Location loc = warpOp.getLoc();
243 OpBuilder::InsertionGuard g(rewriter);
// Guard: laneid == 0 selects the single lane that runs the sequential body.
248 Value isLane0 = arith::CmpIOp::create(
249 rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
250 auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
252 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
// Phase 1: each lane stores its distributed arg to a buffer; lane 0 reloads
// the full sequential value to replace the region's block arguments.
256 SmallVector<Value> bbArgReplacements;
257 for (
const auto &it : llvm::enumerate(warpOp.getArgs())) {
258 Value sequentialVal = warpOpBody->
getArgument(it.index());
259 Value distributedVal = it.value();
260 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
261 warpOp.getLaneid(), c0);
265 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
268 helper.buildStore(rewriter, loc, distributedVal, buffer);
271 bbArgReplacements.push_back(
272 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
// Synchronize only if something was actually communicated.
276 if (!warpOp.getArgs().empty()) {
278 options.warpSynchronizationFn(loc, rewriter, warpOp);
282 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
// Phase 2 (inverse direction): lane 0 stores each yielded sequential value;
// every lane reloads its distributed slice as the warp op's replacement.
288 SmallVector<Value> replacements;
289 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
290 Location yieldLoc = yieldOp.getLoc();
291 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
292 Value sequentialVal = it.value();
293 Value distributedVal = warpOp->getResult(it.index());
294 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
295 warpOp.getLaneid(), c0);
299 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
305 helper.buildStore(rewriter, loc, sequentialVal, buffer);
316 replacements.push_back(
317 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
321 if (!yieldOp.getOperands().empty()) {
323 options.warpSynchronizationFn(loc, rewriter, warpOp);
329 scf::YieldOp::create(rewriter, yieldLoc);
332 rewriter.
replaceOp(warpOp, replacements);
338 const WarpExecuteOnLane0LoweringOptions &options;
// Compute the per-lane (distributed) type for `originalType` given the
// distribution map and warp size: each mapped dim is divided by warpSize;
// when the dim does not divide evenly but warpSize is a multiple of the dim,
// the dim collapses to 1 and the remaining warpSize is carried to later dims.
// NOTE(review): most of this function's lines are missing from the extraction.
351static VectorType getDistributedType(VectorType originalType,
AffineMap map,
359 if (targetShape[position] % warpSize != 0) {
360 if (warpSize % targetShape[position] != 0) {
363 warpSize /= targetShape[position];
364 targetShape[position] = 1;
367 targetShape[position] = targetShape[position] / warpSize;
374 VectorType targetType =
375 VectorType::get(targetShape, originalType.getElementType());
// Collect values that are defined inside `warpOp` but used from within
// `innerRegion` ("escaping" values), together with their original types and
// the distributed types they should take (vectors go through
// getDistributedType with the warp size). Returns the three collections as a
// tuple. NOTE(review): return type and several interior lines not visible.
385getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp,
Region &innerRegion,
387 llvm::SmallSetVector<Value, 32> escapingValues;
// Nothing to do for an empty region.
390 if (innerRegion.
empty())
391 return {std::move(escapingValues), std::move(escapingValueTypes),
392 std::move(escapingValueDistTypes)};
395 if (warpOp->isAncestor(parent)) {
// SetVector::insert returns false for duplicates — each value once.
396 if (!escapingValues.insert(operand->
get()))
399 if (
auto vecType = dyn_cast<VectorType>(distType)) {
401 distType = getDistributedType(vecType, map,
402 map.
isEmpty() ? 1 : warpOp.getWarpSize());
404 escapingValueTypes.push_back(operand->
get().
getType());
405 escapingValueDistTypes.push_back(distType);
408 return {std::move(escapingValues), std::move(escapingValueTypes),
409 std::move(escapingValueDistTypes)};
// Pattern distributing a vector.transfer_write yielded last in a warp op.
// NOTE(review): class header and many interior lines are missing; comments
// describe only the visible statements.
433 unsigned maxNumElementsToExtract, PatternBenefit
b = 1)
434 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)),
435 maxNumElementsToExtract(maxNumElementsToExtract) {}
// Distribute the write across lanes: clone it with the distributed vector
// (and mask) type, delinearize the lane id over seq/dist shape ratios, and
// offset each mapped index by laneId * distributedDimSize.
439 LogicalResult tryDistributeOp(RewriterBase &rewriter,
440 vector::TransferWriteOp writeOp,
441 WarpExecuteOnLane0Op warpOp)
const {
442 VectorType writtenVectorType = writeOp.getVectorType();
// 0-d writes are handled elsewhere (presumably by the extract path).
446 if (writtenVectorType.getRank() == 0)
450 AffineMap map = distributionMapFn(writeOp.getVector());
451 VectorType targetType =
452 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
458 if (writeOp.getMask()) {
// Masked path only supports minor-identity permutation maps.
465 if (!writeOp.getPermutationMap().isMinorIdentity())
468 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
473 vector::TransferWriteOp newWriteOp =
474 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
478 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
// Delinearize laneId into one id per distributed dim, basis = seq/dist.
484 SmallVector<OpFoldResult> delinearizedIdSizes;
485 for (
auto [seqSize, distSize] :
486 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
487 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
488 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
490 SmallVector<Value> delinearized;
492 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
493 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
// Fallback: replicate the lane id per rank (no delinearization needed).
499 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
502 AffineMap indexMap = map.
compose(newWriteOp.getPermutationMap());
503 Location loc = newWriteOp.getLoc();
504 SmallVector<Value>
indices(newWriteOp.getIndices().begin(),
505 newWriteOp.getIndices().end());
508 bindDims(newWarpOp.getContext(), d0, d1);
509 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
512 unsigned indexPos = indexExpr.getPosition();
513 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
514 Value laneId = delinearized[vectorPos];
// new index = old index + distDimSize * laneId (affine-applied).
518 rewriter, loc, d0 + scale * d1, {
indices[indexPos], laneId});
520 newWriteOp.getIndicesMutable().assign(
indices);
// Fallback when distribution fails: funnel the vector out of the warp op and
// re-issue the write in a trailing lane-0-only warp op, capped by
// maxNumElementsToExtract.
526 LogicalResult tryExtractOp(RewriterBase &rewriter,
527 vector::TransferWriteOp writeOp,
528 WarpExecuteOnLane0Op warpOp)
const {
529 Location loc = writeOp.getLoc();
530 VectorType vecType = writeOp.getVectorType();
532 if (vecType.getNumElements() > maxNumElementsToExtract) {
536 "writes more elements ({0}) than allowed to extract ({1})",
537 vecType.getNumElements(), maxNumElementsToExtract));
// If the body is only the write + terminator, there is nothing to gain.
541 if (llvm::all_of(warpOp.getOps(),
542 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
545 SmallVector<Value> yieldValues = {writeOp.getVector()};
546 SmallVector<Type> retTypes = {vecType};
547 SmallVector<size_t> newRetIndices;
549 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
553 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc,
TypeRange(),
554 newWarpOp.getLaneid(),
555 newWarpOp.getWarpSize());
556 Block &body = secondWarpOp.getBodyRegion().front();
559 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
560 newWriteOp.getValueToStoreMutable().assign(
561 newWarpOp.getResult(newRetIndices[0]));
563 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
// Entry point: only fires when the write is the op immediately before the
// terminator and all its operands are produced in-region or defined outside.
568 PatternRewriter &rewriter)
const override {
569 gpu::YieldOp yield = warpOp.getTerminator();
570 Operation *lastNode = yield->getPrevNode();
571 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
575 Value maybeMask = writeOp.getMask();
576 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
577 return writeOp.getVector() == value ||
578 (maybeMask && maybeMask == value) ||
579 warpOp.isDefinedOutsideOfRegion(value);
583 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
// Masked writes cannot take the extract fallback.
587 if (writeOp.getMask())
590 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
// Clone the write with the distributed value (and mask) routed through a new
// warp op's extra results.
600 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
601 WarpExecuteOnLane0Op warpOp,
602 vector::TransferWriteOp writeOp,
603 VectorType targetType,
604 VectorType maybeMaskType)
const {
605 assert(writeOp->getParentOp() == warpOp &&
606 "write must be nested immediately under warp");
607 OpBuilder::InsertionGuard g(rewriter);
608 SmallVector<size_t> newRetIndices;
609 WarpExecuteOnLane0Op newWarpOp;
612 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
613 TypeRange{targetType, maybeMaskType}, newRetIndices);
616 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
621 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
623 newWriteOp.getValueToStoreMutable().assign(
624 newWarpOp.getResult(newRetIndices[0]));
626 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
// Cap for the tryExtractOp fallback path.
631 unsigned maxNumElementsToExtract = 1;
// Elementwise pattern: sink an elementwise op's result through the warp op by
// yielding its operands (retyped to the distributed shape) and recreating the
// op outside. NOTE(review): fragmentary; several lines missing.
654 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
655 PatternRewriter &rewriter)
const override {
656 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
664 Value distributedVal = warpOp.getResult(operandIndex);
665 SmallVector<Value> yieldValues;
666 SmallVector<Type> retTypes;
667 Location loc = warpOp.getLoc();
// Each operand keeps its element type but takes the distributed shape.
670 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
672 auto operandType = cast<VectorType>(operand.
get().
getType());
674 VectorType::get(vecType.getShape(), operandType.getElementType());
677 assert(!isa<VectorType>(operandType) &&
678 "unexpected yield of vector from op with scalar result type");
679 targetType = operandType;
681 retTypes.push_back(targetType);
682 yieldValues.push_back(operand.
get());
684 SmallVector<size_t> newRetIndices;
685 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
686 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
688 SmallVector<Value> newOperands(elementWise->
getOperands().begin(),
690 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
691 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
693 OpBuilder::InsertionGuard g(rewriter);
696 rewriter, loc, elementWise, newOperands,
697 {newWarpOp.getResult(operandIndex).getType()});
// Splat-constant pattern: a yielded splat constant is rematerialized outside
// the warp op at the distributed result type.
721 PatternRewriter &rewriter)
const override {
722 OpOperand *yieldOperand =
727 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
734 Attribute scalarAttr = dense.getSplatValue<Attribute>();
736 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
737 Location loc = warpOp.getLoc();
739 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
// vector.step pattern: only supported when the step vector has exactly
// warp-size elements, in which case each lane's slice is its own lane id
// (broadcast below; the subsequent arithmetic is not visible in this chunk).
768 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
769 PatternRewriter &rewriter)
const override {
770 OpOperand *yieldOperand =
771 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
777 if (resTy.getNumElements() !=
static_cast<int64_t
>(warpOp.getWarpSize()))
780 llvm::formatv(
"Expected result size ({0}) to be of warp size ({1})",
781 resTy.getNumElements(), warpOp.getWarpSize()));
782 VectorType newVecTy =
783 cast<VectorType>(warpOp.getResult(operandIdx).getType());
785 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
786 newVecTy, warpOp.getLaneid());
// transfer_read pattern: hoist a single-use transfer_read out of the warp op,
// retyping it to the distributed shape and offsetting mapped indices by the
// delinearized lane id. NOTE(review): fragmentary; comments hedged.
813 PatternRewriter &rewriter)
const override {
817 OpOperand *operand =
getWarpResult(warpOp, [](Operation *op) {
819 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
823 warpOp,
"warp result is not a vector.transfer_read op");
// The source memref must not be produced inside the region.
827 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
829 read,
"source must be defined outside of the region");
832 Value distributedVal = warpOp.getResult(operandIndex);
834 SmallVector<Value, 4>
indices(read.getIndices().begin(),
835 read.getIndices().end());
836 auto sequentialType = cast<VectorType>(read.getResult().getType());
837 auto distributedType = cast<VectorType>(distributedVal.
getType());
839 AffineMap indexMap = map.
compose(read.getPermutationMap());
// One lane-id component per distributed dimension.
843 SmallVector<Value> delinearizedIds;
845 distributedType.getShape(), warpOp.getWarpSize(),
846 warpOp.getLaneid(), delinearizedIds)) {
848 read,
"cannot delinearize lane ID for distribution");
850 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
// Yield indices, padding and (optionally) the mask through the new warp op
// so they are available outside.
853 OpBuilder::InsertionGuard g(rewriter);
854 SmallVector<Value> additionalResults(
indices.begin(),
indices.end());
855 SmallVector<Type> additionalResultTypes(
indices.size(),
857 additionalResults.push_back(read.getPadding());
858 additionalResultTypes.push_back(read.getPadding().getType());
860 bool hasMask =
false;
861 if (read.getMask()) {
871 read,
"non-trivial permutation maps not supported");
872 VectorType maskType =
873 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
874 additionalResults.push_back(read.getMask());
875 additionalResultTypes.push_back(maskType);
878 SmallVector<size_t> newRetIndices;
880 rewriter, warpOp, additionalResults, additionalResultTypes,
882 distributedVal = newWarpOp.getResult(operandIndex);
885 SmallVector<Value> newIndices;
886 for (int64_t i = 0, e =
indices.size(); i < e; ++i)
887 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
892 bindDims(read.getContext(), d0, d1);
893 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
896 unsigned indexPos = indexExpr.getPosition();
897 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
898 int64_t scale = distributedType.getDimSize(vectorPos);
// index += distDimSize * laneIdComponent.
900 rewriter, read.getLoc(), d0 + scale * d1,
901 {newIndices[indexPos], delinearizedIds[vectorPos]});
905 Value newPadding = newWarpOp.getResult(newRetIndices[
indices.size()]);
908 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
910 auto newRead = vector::TransferReadOp::create(
911 rewriter, read.getLoc(), distributedVal.
getType(), read.getBase(),
912 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
913 read.getInBoundsAttr());
// Dead/duplicate-result cleanup: rebuild the warp op keeping each distinct
// yielded value once, then remap old results to the deduplicated positions.
// NOTE(review): fragmentary; some filtering logic is not visible.
925 PatternRewriter &rewriter)
const override {
926 SmallVector<Type> newResultTypes;
927 newResultTypes.reserve(warpOp->getNumResults());
928 SmallVector<Value> newYieldValues;
929 newYieldValues.reserve(warpOp->getNumResults());
932 gpu::YieldOp yield = warpOp.getTerminator();
943 for (OpResult
result : warpOp.getResults()) {
946 Value yieldOperand = yield.getOperand(
result.getResultNumber());
// insert() keeps the first position for each distinct operand; later
// duplicates map to that same slot.
947 auto it = dedupYieldOperandPositionMap.insert(
948 std::make_pair(yieldOperand, newResultTypes.size()));
949 dedupResultPositionMap.insert(std::make_pair(
result, it.first->second));
952 newResultTypes.push_back(
result.getType());
953 newYieldValues.push_back(yieldOperand);
// Nothing deduplicated: bail out to avoid an infinite rewrite loop.
956 if (yield.getNumOperands() == newYieldValues.size())
960 rewriter, warpOp, newYieldValues, newResultTypes);
963 newWarpOp.getBody()->walk([&](Operation *op) {
969 SmallVector<Value> newValues;
970 newValues.reserve(warpOp->getNumResults());
971 for (OpResult
result : warpOp.getResults()) {
973 newValues.push_back(Value());
976 newWarpOp.getResult(dedupResultPositionMap.lookup(
result)));
// Operand-forwarding pattern: when a yielded value is (or forwards) a warp-op
// operand/block-argument, bypass the region and use the outside value
// directly. NOTE(review): fragmentary; selection criteria partially missing.
988 PatternRewriter &rewriter)
const override {
989 gpu::YieldOp yield = warpOp.getTerminator();
991 unsigned resultIndex;
992 for (OpOperand &operand : yield->getOpOperands()) {
1001 valForwarded = operand.
get();
// Block arguments of the warp body map 1:1 to the op's `args` operands.
1005 auto arg = dyn_cast<BlockArgument>(operand.
get());
1006 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1008 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1011 valForwarded = warpOperand;
// Broadcast pattern: move a yielded vector.broadcast after the warp op by
// yielding its (uniform) source and re-broadcasting to the distributed type.
1029 PatternRewriter &rewriter)
const override {
1030 OpOperand *operand =
1036 Location loc = broadcastOp.getLoc();
1038 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1039 Value broadcastSrc = broadcastOp.getSource();
1040 Type broadcastSrcType = broadcastSrc.
getType();
// The source must still be broadcastable to the distributed result type.
1047 vector::BroadcastableToResult::Success)
1049 SmallVector<size_t> newRetIndices;
1051 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1053 Value broadcasted = vector::BroadcastOp::create(
1054 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
// shape_cast pattern: yield the cast's source at an inferred distributed
// source type and re-apply the cast outside the warp op.
1066 PatternRewriter &rewriter)
const override {
1067 OpOperand *operand =
1075 auto castDistributedType =
1076 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1077 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1078 VectorType castResultType = castDistributedType;
1080 FailureOr<VectorType> maybeSrcType =
1081 inferDistributedSrcType(castDistributedType, castOriginalType);
1082 if (
failed(maybeSrcType))
1084 castDistributedType = *maybeSrcType;
1086 SmallVector<size_t> newRetIndices;
1088 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1091 Value newCast = vector::ShapeCastOp::create(
1092 rewriter, oldCastOp.getLoc(), castResultType,
1093 newWarpOp->getResult(newRetIndices[0]));
// Infer the distributed type of the cast's *source* from the distributed
// result type: equal ranks pass through; a lower-rank distributed type is
// padded with leading 1s; a higher-rank one must have leading unit dims,
// which are dropped. NOTE(review): lines 1112-1122 are missing here.
1099 static FailureOr<VectorType>
1100 inferDistributedSrcType(VectorType distributedType, VectorType srcType) {
1101 unsigned distributedRank = distributedType.getRank();
1102 unsigned srcRank = srcType.getRank();
1103 if (distributedRank == srcRank)
1105 return distributedType;
1106 if (distributedRank < srcRank) {
1109 SmallVector<int64_t> shape(srcRank - distributedRank, 1);
1110 llvm::append_range(shape, distributedType.getShape());
1111 return VectorType::get(shape, distributedType.getElementType());
1123 return VectorType::get(distributedType.getNumElements(),
1124 srcType.getElementType());
1129 unsigned excessDims = distributedRank - srcRank;
1130 ArrayRef<int64_t> shape = distributedType.getShape();
1131 if (!llvm::all_of(shape.take_front(excessDims),
1132 [](int64_t d) { return d == 1; }))
1134 return VectorType::get(shape.drop_front(excessDims),
1135 distributedType.getElementType());
// Mask pattern (create_mask / constant_mask via template): rebuild the mask
// outside the warp op with per-lane bounds derived from the delinearized lane
// id: newBound = seqBound - laneIdComponent * distDimSize.
// NOTE(review): fragmentary; clamping logic between 1208-1216 not visible.
1159template <
typename OpType,
1160 typename = std::enable_if_t<llvm::is_one_of<
1161 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
1164 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1165 PatternRewriter &rewriter)
const override {
1166 OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
// All mask bounds must be uniform (defined outside the region).
1175 !llvm::all_of(mask->
getOperands(), [&](Value value) {
1176 return warpOp.isDefinedOutsideOfRegion(value);
1180 Location loc = mask->
getLoc();
1183 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1185 ArrayRef<int64_t> seqShape = seqType.getShape();
1186 ArrayRef<int64_t> distShape = distType.getShape();
1187 SmallVector<Value> materializedOperands;
1188 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
1189 materializedOperands.append(mask->
getOperands().begin(),
// constant_mask carries its bounds as attributes; materialize them.
1192 auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
1193 auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
1194 for (
auto dimSize : dimSizes)
1195 materializedOperands.push_back(
1202 SmallVector<Value> delinearizedIds;
1203 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1204 warpOp.getWarpSize(), warpOp.getLaneid(),
1207 mask,
"cannot delinearize lane ID for distribution");
1208 assert(!delinearizedIds.empty());
1216 SmallVector<Value> newOperands;
1217 for (
int i = 0, e = distShape.size(); i < e; ++i) {
1223 Value maskDimIdx = affine::makeComposedAffineApply(
1224 rewriter, loc, s1 - s0 * distShape[i],
1225 {delinearizedIds[i], materializedOperands[i]});
1226 newOperands.push_back(maskDimIdx);
1230 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
// insert_strided_slice pattern: distribute when the distributed dim of the
// dest maps into the source's dims and is fully inserted; both source and
// dest are yielded at distributed types and the insert is recreated outside.
1265 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1266 PatternRewriter &rewriter)
const override {
1267 OpOperand *operand =
1268 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1274 auto distributedType =
1275 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1278 if (distributedType.getRank() < 2)
1280 insertOp,
"result vector type must be 2D or higher");
1283 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1284 int64_t destDistributedDim =
1286 assert(destDistributedDim != -1 &&
"could not find distributed dimension");
1288 VectorType srcType = insertOp.getSourceVectorType();
1289 VectorType destType = insertOp.getDestVectorType();
// Source dims align with the trailing dims of dest (strided-slice rule),
// so map the dest's distributed dim into source coordinates.
1294 int64_t sourceDistributedDim =
1295 destDistributedDim - (destType.getRank() - srcType.getRank());
1296 if (sourceDistributedDim < 0)
1299 "distributed dimension must be in the last k dims of dest vector");
1301 if (srcType.getDimSize(sourceDistributedDim) !=
1302 destType.getDimSize(destDistributedDim))
1304 insertOp,
"distributed dimension must be fully inserted");
1305 SmallVector<int64_t> newSourceDistShape(
1306 insertOp.getSourceVectorType().getShape());
1307 newSourceDistShape[sourceDistributedDim] =
1308 distributedType.getDimSize(destDistributedDim);
1310 VectorType::get(newSourceDistShape, distributedType.getElementType());
1311 VectorType newDestTy = distributedType;
1312 SmallVector<size_t> newRetIndices;
1313 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1314 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1315 {newSourceTy, newDestTy}, newRetIndices);
1317 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1318 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
// Offsets/strides carry over unchanged: the distributed dim is whole.
1321 Value newInsert = vector::InsertStridedSliceOp::create(
1322 rewriter, insertOp.getLoc(), distributedDest.
getType(),
1323 distributedSource, distributedDest, insertOp.getOffsets(),
1324 insertOp.getStrides());
// extract_strided_slice pattern: mirror of the insert case — allowed when the
// distributed dim is either untouched by the extraction or fully extracted;
// the source is yielded at a distributed type and sizes are patched.
1354 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1355 PatternRewriter &rewriter)
const override {
1356 OpOperand *operand =
1357 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1363 auto distributedType =
1364 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1367 if (distributedType.getRank() < 2)
1369 extractOp,
"result vector type must be 2D or higher");
1372 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1374 assert(distributedDim != -1 &&
"could not find distributed dimension");
1376 int64_t numOfExtractedDims =
1377 static_cast<int64_t
>(extractOp.getSizes().size());
// If the extraction touches the distributed dim, it must take the whole dim
// starting at offset 0.
1384 if (distributedDim < numOfExtractedDims) {
1385 int64_t distributedDimOffset =
1386 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1388 int64_t distributedDimSize =
1389 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1391 if (distributedDimOffset != 0 ||
1392 distributedDimSize != yieldedType.getDimSize(distributedDim))
1394 extractOp,
"distributed dimension must be fully extracted");
1396 SmallVector<int64_t> newDistributedShape(
1397 extractOp.getSourceVectorType().getShape());
1398 newDistributedShape[distributedDim] =
1399 distributedType.getDimSize(distributedDim);
1400 auto newDistributedType =
1401 VectorType::get(newDistributedShape, distributedType.getElementType());
1402 SmallVector<size_t> newRetIndices;
1403 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1404 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
// The `sizes` attribute must be rewritten to the distributed dim size.
1407 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1408 extractOp.getSizes(), [](Attribute attr) { return attr; });
1410 if (distributedDim <
static_cast<int64_t
>(distributedSizes.size()))
1412 distributedType.getDimSize(distributedDim));
1416 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1417 Value newExtract = vector::ExtractStridedSliceOp::create(
1418 rewriter, extractOp.getLoc(), distributedType, distributedVec,
1419 extractOp.getOffsets(),
1420 ArrayAttr::get(rewriter.
getContext(), distributedSizes),
1421 extractOp.getStrides());
// vector.extract pattern (vector result): yield the extract's source through
// the warp op and re-extract outside. Low-rank (<=1) sources that are not
// distributed pass through unchanged; otherwise the source is yielded at a
// distributed type whose trailing dims match the distributed result.
1432 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1433 PatternRewriter &rewriter)
const override {
1434 OpOperand *operand =
1435 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1440 VectorType extractSrcType = extractOp.getSourceVectorType();
1441 Location loc = extractOp.getLoc();
1444 if (extractSrcType.getRank() <= 1) {
// Undistributed case: result type passes through the warp op unchanged.
1450 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1456 SmallVector<size_t> newRetIndices;
1457 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1458 rewriter, warpOp, {extractOp.getSource()},
1459 {extractOp.getSourceVectorType()}, newRetIndices);
1461 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1463 Value newExtract = vector::ExtractOp::create(
1464 rewriter, loc, distributedVec, extractOp.getMixedPosition());
1471 auto distributedType =
1472 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1473 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1475 assert(distributedDim != -1 &&
"could not find distributed dimension");
1476 (void)distributedDim;
// Source's trailing dims (past the extracted indices) take the distributed
// result's sizes; leading dims stay sequential.
1479 SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1480 for (
int i = 0; i < distributedType.getRank(); ++i)
1481 newDistributedShape[i + extractOp.getNumIndices()] =
1482 distributedType.getDimSize(i);
1483 auto newDistributedType =
1484 VectorType::get(newDistributedShape, distributedType.getElementType());
1485 SmallVector<size_t> newRetIndices;
1486 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1487 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1490 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1492 Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1493 extractOp.getMixedPosition());
// Scalar-extract pattern: distribute a 0-D/1-D vector.extract by computing
// which lane owns the element and broadcasting it to all lanes via the
// injected warpShuffleFromIdxFn callback. Only f32/i32 elements supported.
1503 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1504 PatternBenefit
b = 1)
1505 : WarpDistributionPattern(ctx,
b), warpShuffleFromIdxFn(std::move(fn)) {}
1506 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1507 PatternRewriter &rewriter)
const override {
1508 OpOperand *operand =
1509 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1514 VectorType extractSrcType = extractOp.getSourceVectorType();
1516 if (extractSrcType.getRank() > 1) {
1518 extractOp,
"only 0-D or 1-D source supported for now");
// Restriction presumably comes from the shuffle intrinsic's operand types.
1522 if (!extractSrcType.getElementType().isF32() &&
1523 !extractSrcType.getElementType().isInteger(32))
1525 extractOp,
"only f32/i32 element types are supported");
1526 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1527 Type elType = extractSrcType.getElementType();
1528 VectorType distributedVecType;
1529 if (!is0dOrVec1Extract) {
1530 assert(extractSrcType.getRank() == 1 &&
1531 "expected that extract src rank is 0 or 1");
// The 1-D source must split evenly across the warp.
1532 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1534 int64_t elementsPerLane =
1535 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1536 distributedVecType = VectorType::get({elementsPerLane}, elType);
1538 distributedVecType = extractSrcType;
// Yield the source plus any dynamic position operands out of the region.
1541 SmallVector<Value> additionalResults{extractOp.getSource()};
1542 SmallVector<Type> additionalResultTypes{distributedVecType};
1543 additionalResults.append(
1544 SmallVector<Value>(extractOp.getDynamicPosition()));
1545 additionalResultTypes.append(
1546 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1548 Location loc = extractOp.getLoc();
1549 SmallVector<size_t> newRetIndices;
1550 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1551 rewriter, warpOp, additionalResults, additionalResultTypes,
1554 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Single-element source: every lane already has the value; plain extract.
1558 if (is0dOrVec1Extract) {
1560 SmallVector<int64_t>
indices(extractSrcType.getRank(), 0);
1562 vector::ExtractOp::create(rewriter, loc, distributedVec,
indices);
1568 int64_t staticPos = extractOp.getStaticPosition()[0];
1569 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1570 ? (newWarpOp->getResult(newRetIndices[1]))
// Owning lane = pos / elementsPerLane; in-lane slot = pos % elementsPerLane.
1574 int64_t elementsPerLane = distributedVecType.getShape()[0];
1577 Value broadcastFromTid = affine::makeComposedAffineApply(
1578 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1581 elementsPerLane == 1
1583 : affine::makeComposedAffineApply(rewriter, loc,
1584 sym0 % elementsPerLane, pos);
1586 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1589 Value shuffled = warpShuffleFromIdxFn(
1590 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1596 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
// Scalar-insert pattern (0-D/1-D dest): compute the owning lane from the
// insert position; only that lane performs the insert (guarded by scf.if),
// other lanes forward their distributed vector unchanged.
1603 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1604 PatternRewriter &rewriter)
const override {
1605 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1610 VectorType vecType = insertOp.getDestVectorType();
1611 VectorType distrType =
1612 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1615 if (vecType.getRank() > 1) {
1617 insertOp,
"only 0-D or 1-D source supported for now");
// Yield dest, stored value and any dynamic position out of the region.
1621 SmallVector<Value> additionalResults{insertOp.getDest(),
1622 insertOp.getValueToStore()};
1623 SmallVector<Type> additionalResultTypes{
1624 distrType, insertOp.getValueToStore().getType()};
1625 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1626 additionalResultTypes.append(
1627 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1629 Location loc = insertOp.getLoc();
1630 SmallVector<size_t> newRetIndices;
1631 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1632 rewriter, warpOp, additionalResults, additionalResultTypes,
1635 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1636 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1640 if (vecType.getRank() != 0) {
1641 int64_t staticPos = insertOp.getStaticPosition()[0];
1642 pos = ShapedType::isDynamic(staticPos)
1643 ? (newWarpOp->getResult(newRetIndices[2]))
// Undistributed dest: every lane holds the whole vector; insert directly.
1648 if (vecType == distrType) {
1650 SmallVector<OpFoldResult>
indices;
1654 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
// Owning lane = pos / elementsPerLane; slot within lane = pos % per-lane.
1663 int64_t elementsPerLane = distrType.getShape()[0];
1666 Value insertingLane = affine::makeComposedAffineApply(
1667 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1669 OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1670 rewriter, loc, sym0 % elementsPerLane, pos);
1671 Value isInsertingLane =
1672 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1673 newWarpOp.getLaneid(), insertingLane);
1676 rewriter, loc, isInsertingLane,
// then: perform the insert; else: yield the untouched distributed vector.
1678 [&](OpBuilder &builder, Location loc) {
1679 Value newInsert = vector::InsertOp::create(
1680 builder, loc, newSource, distributedVec, newPos);
1681 scf::YieldOp::create(builder, loc, newInsert);
1684 [&](OpBuilder &builder, Location loc) {
1685 scf::YieldOp::create(builder, loc, distributedVec);
1695 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1696 PatternRewriter &rewriter)
const override {
1697 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1702 Location loc = insertOp.getLoc();
1705 if (insertOp.getDestVectorType().getRank() <= 1) {
1711 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1714 SmallVector<size_t> newRetIndices;
1715 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1716 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1717 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1720 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1721 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1722 Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1724 insertOp.getMixedPosition());
1731 auto distrDestType =
1732 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1733 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1734 int64_t distrDestDim = -1;
1735 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1736 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1739 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1743 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1746 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1747 SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1754 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1755 if (distrSrcDim >= 0)
1756 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1758 VectorType::get(distrSrcShape, distrDestType.getElementType());
1761 SmallVector<size_t> newRetIndices;
1762 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1763 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1764 {distrSrcType, distrDestType}, newRetIndices);
1766 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1767 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1771 if (distrSrcDim >= 0) {
1773 newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1775 insertOp.getMixedPosition());
1778 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1779 SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1783 rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1784 Value isInsertingLane =
1785 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1786 newWarpOp.getLaneid(), insertingLane);
1788 newPos[distrDestDim] %= elementsPerLane;
1789 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1790 Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1791 distributedDest, newPos);
1792 scf::YieldOp::create(builder, loc, newInsert);
1794 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1795 scf::YieldOp::create(builder, loc, distributedDest);
1797 newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1799 nonInsertingBuilder)
1836 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
1838 PatternRewriter &rewriter)
const override {
1839 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1841 Operation *lastNode = warpOpYield->getPrevNode();
1842 auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1853 SmallVector<Value> nonIfYieldValues;
1854 SmallVector<unsigned> nonIfYieldIndices;
1855 llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1856 llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1857 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1860 nonIfYieldValues.push_back(yieldOperand.
get());
1861 nonIfYieldIndices.push_back(yieldOperandIdx);
1864 OpResult ifResult = cast<OpResult>(yieldOperand.
get());
1866 ifResultMapping[yieldOperandIdx] = ifResultIdx;
1869 if (!isa<VectorType>(ifResult.
getType()))
1871 VectorType distType =
1872 cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1873 ifResultDistTypes[ifResultIdx] = distType;
1878 auto [escapingValuesThen, escapingValueInputTypesThen,
1879 escapingValueDistTypesThen] =
1880 getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1882 auto [escapingValuesElse, escapingValueInputTypesElse,
1883 escapingValueDistTypesElse] =
1884 getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1886 if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1887 llvm::is_contained(escapingValueDistTypesElse, Type{}))
1895 SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1896 newWarpOpYieldValues.append(escapingValuesThen.begin(),
1897 escapingValuesThen.end());
1898 newWarpOpYieldValues.append(escapingValuesElse.begin(),
1899 escapingValuesElse.end());
1900 SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1901 newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1902 escapingValueDistTypesThen.end());
1903 newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1904 escapingValueDistTypesElse.end());
1906 for (
auto [idx, val] :
1907 llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1908 newWarpOpYieldValues.push_back(val);
1909 newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1913 SmallVector<size_t> newIndices;
1915 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
1917 SmallVector<Type> newIfOpDistResTypes;
1918 for (
auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1919 Type distType = cast<Value>(res).getType();
1920 if (
auto vecType = dyn_cast<VectorType>(distType)) {
1921 AffineMap map = distributionMapFn(cast<Value>(res));
1923 distType = ifResultDistTypes.count(i)
1924 ? ifResultDistTypes[i]
1925 : getDistributedType(
1927 map.
isEmpty() ? 1 : newWarpOp.getWarpSize());
1929 newIfOpDistResTypes.push_back(distType);
1932 OpBuilder::InsertionGuard g(rewriter);
1934 auto newIfOp = scf::IfOp::create(
1935 rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1936 newWarpOp.getResult(newIndices[0]),
static_cast<bool>(ifOp.thenBlock()),
1937 static_cast<bool>(ifOp.elseBlock()));
1938 auto encloseRegionInWarpOp =
1940 llvm::SmallSetVector<Value, 32> &escapingValues,
1941 SmallVector<Type> &escapingValueInputTypes,
1942 size_t warpResRangeStart) {
1943 OpBuilder::InsertionGuard g(rewriter);
1947 llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1948 SmallVector<Value> innerWarpInputVals;
1949 SmallVector<Type> innerWarpInputTypes;
1950 for (
size_t i = 0; i < escapingValues.size();
1951 ++i, ++warpResRangeStart) {
1952 innerWarpInputVals.push_back(
1953 newWarpOp.getResult(newIndices[warpResRangeStart]));
1954 escapeValToBlockArgIndex[escapingValues[i]] =
1955 innerWarpInputTypes.size();
1956 innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1958 auto innerWarp = WarpExecuteOnLane0Op::create(
1959 rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1960 newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1961 innerWarpInputVals, innerWarpInputTypes);
1963 innerWarp.getWarpRegion().takeBody(*oldIfBranch->
getParent());
1964 innerWarp.getWarpRegion().addArguments(
1965 innerWarpInputTypes,
1966 SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1968 SmallVector<Value> yieldOperands;
1970 yieldOperands.push_back(operand);
1974 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1976 scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1980 innerWarp.walk([&](Operation *op) {
1982 auto it = escapeValToBlockArgIndex.find(operand.
get());
1983 if (it == escapeValToBlockArgIndex.end())
1985 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1988 mlir::vector::moveScalarUniformCode(innerWarp);
1990 encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1991 &newIfOp.getThenRegion().front(), escapingValuesThen,
1992 escapingValueInputTypesThen, 1);
1993 if (!ifOp.getElseRegion().empty())
1994 encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1995 &newIfOp.getElseRegion().front(),
1996 escapingValuesElse, escapingValueInputTypesElse,
1997 1 + escapingValuesThen.size());
2000 for (
auto [origIdx, newIdx] : ifResultMapping)
2002 newIfOp.getResult(newIdx), newIfOp);
2045 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
2046 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2047 PatternRewriter &rewriter)
const override {
2048 gpu::YieldOp warpOpYield = warpOp.getTerminator();
2050 Operation *lastNode = warpOpYield->getPrevNode();
2051 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
2056 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2057 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2059 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2070 SmallVector<Value> nonForYieldedValues;
2071 SmallVector<unsigned> nonForResultIndices;
2072 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2073 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2074 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2077 nonForYieldedValues.push_back(yieldOperand.
get());
2081 OpResult forResult = cast<OpResult>(yieldOperand.
get());
2086 if (!isa<VectorType>(forResult.
getType()))
2088 VectorType distType = cast<VectorType>(
2090 forResultDistTypes[forResultNumber] = distType;
2098 SmallVector<Value> newWarpOpYieldValues;
2099 SmallVector<Type> newWarpOpDistTypes;
2100 newWarpOpYieldValues.insert(
2101 newWarpOpYieldValues.end(),
2102 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2103 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2104 {forOp.getLowerBound().getType(),
2105 forOp.getUpperBound().getType(),
2106 forOp.getStep().getType()});
2107 for (
auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2108 newWarpOpYieldValues.push_back(initArg);
2110 Type distType = initArg.getType();
2111 if (
auto vecType = dyn_cast<VectorType>(distType)) {
2115 AffineMap map = distributionMapFn(initArg);
2117 forResultDistTypes.count(i)
2118 ? forResultDistTypes[i]
2119 : getDistributedType(vecType, map,
2120 map.
isEmpty() ? 1 : warpOp.getWarpSize());
2122 newWarpOpDistTypes.push_back(distType);
2125 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2126 escapingValues.begin(), escapingValues.end());
2127 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2128 escapingValueDistTypes.begin(),
2129 escapingValueDistTypes.end());
2133 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2134 newWarpOpYieldValues.push_back(v);
2135 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2138 SmallVector<size_t> newIndices;
2139 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2140 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2144 const unsigned initArgsStartIdx = 3;
2145 const unsigned escapingValuesStartIdx =
2147 forOp.getInitArgs().size();
2149 SmallVector<Value> newForOpOperands;
2150 for (
size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2151 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2154 OpBuilder::InsertionGuard g(rewriter);
2156 auto newForOp = scf::ForOp::create(
2157 rewriter, forOp.getLoc(),
2158 newWarpOp.getResult(newIndices[0]),
2159 newWarpOp.getResult(newIndices[1]),
2160 newWarpOp.getResult(newIndices[2]), newForOpOperands,
2161 nullptr, forOp.getUnsignedCmp());
2167 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2168 newForOp.getRegionIterArgs().end());
2169 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2170 forOp.getResultTypes().end());
2174 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2175 for (
size_t i = escapingValuesStartIdx;
2176 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2177 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2178 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2179 innerWarpInputType.size();
2180 innerWarpInputType.push_back(
2181 escapingValueInputTypes[i - escapingValuesStartIdx]);
2184 auto innerWarp = WarpExecuteOnLane0Op::create(
2185 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2186 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2187 innerWarpInputType);
2190 SmallVector<Value> argMapping;
2191 argMapping.push_back(newForOp.getInductionVar());
2192 for (Value args : innerWarp.getBody()->getArguments())
2193 argMapping.push_back(args);
2195 argMapping.resize(forOp.getBody()->getNumArguments());
2196 SmallVector<Value> yieldOperands;
2197 for (Value operand : forOp.getBody()->getTerminator()->getOperands())
2198 yieldOperands.push_back(operand);
2200 rewriter.
eraseOp(forOp.getBody()->getTerminator());
2201 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2206 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2210 if (!innerWarp.getResults().empty())
2211 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2215 for (
auto [origIdx, newIdx] : forResultMapping)
2217 newForOp.getResult(newIdx), newForOp);
2220 newForOp.walk([&](Operation *op) {
2222 auto it = argIndexMapping.find(operand.
get());
2223 if (it == argIndexMapping.end())
2225 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
2230 mlir::vector::moveScalarUniformCode(innerWarp);
2258 WarpOpReduction(MLIRContext *context,
2259 DistributedReductionFn distributedReductionFn,
2260 PatternBenefit benefit = 1)
2261 : WarpDistributionPattern(context, benefit),
2262 distributedReductionFn(std::move(distributedReductionFn)) {}
2264 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2265 PatternRewriter &rewriter)
const override {
2266 OpOperand *yieldOperand =
2267 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2273 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2275 if (vectorType.getRank() != 1)
2277 warpOp,
"Only rank 1 reductions can be distributed.");
2279 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2281 warpOp,
"Reduction vector dimension must match was size.");
2282 if (!reductionOp.getType().isIntOrFloat())
2284 warpOp,
"Reduction distribution currently only supports floats and "
2287 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2290 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2291 SmallVector<Type> retTypes = {
2292 VectorType::get({numElements}, reductionOp.getType())};
2293 if (reductionOp.getAcc()) {
2294 yieldValues.push_back(reductionOp.getAcc());
2295 retTypes.push_back(reductionOp.getAcc().getType());
2297 SmallVector<size_t> newRetIndices;
2298 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2299 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2303 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2306 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2307 reductionOp.getKind(), newWarpOp.getWarpSize());
2308 if (reductionOp.getAcc()) {
2310 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2311 newWarpOp.getResult(newRetIndices[1]));
2318 DistributedReductionFn distributedReductionFn;
2329void mlir::vector::populateDistributeTransferWriteOpPatterns(
2332 patterns.
add<WarpOpTransferWrite>(patterns.
getContext(), distributionMapFn,
2333 maxNumElementsToExtract, benefit);
2336void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2338 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
2340 patterns.
add<WarpOpTransferRead>(patterns.
getContext(), readBenefit);
2341 patterns.
add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2342 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
2343 WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
2344 WarpOpCreateMask<vector::CreateMaskOp>,
2345 WarpOpCreateMask<vector::ConstantMaskOp>,
2346 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2348 patterns.
add<WarpOpExtractScalar>(patterns.
getContext(), warpShuffleFromIdxFn,
2350 patterns.
add<WarpOpScfForOp>(patterns.
getContext(), distributionMapFn,
2352 patterns.
add<WarpOpScfIfOp>(patterns.
getContext(), distributionMapFn,
2356void mlir::vector::populateDistributeReduction(
2358 const DistributedReductionFn &distributedReductionFn,
2360 patterns.
add<WarpOpReduction>(patterns.
getContext(), distributedReductionFn,
2367 return llvm::all_of(op->
getOperands(), definedOutside) &&
2371void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2372 Block *body = warpOp.getBody();
2375 llvm::SmallSetVector<Operation *, 8> opsToMove;
2378 auto isDefinedOutsideOfBody = [&](
Value value) {
2380 return (definingOp && opsToMove.count(definingOp)) ||
2381 warpOp.isDefinedOutsideOfRegion(value);
2388 return isa<VectorType>(result.getType());
2390 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
2391 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead.