23#include "llvm/ADT/SetVector.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVectorExtras.h"
26#include "llvm/Support/FormatVariadic.h"
// NOTE(review): fragment — the function signature above and several lines
// between the visible statements are elided in this view. What is visible:
// walk the ranks of a sequential vs. distributed VectorType, compare dim
// sizes, and finish by constructing something on distributedType's context
// (presumably an AffineMap over the differing dims — confirm in full file).
45 VectorType distributedType) {
51 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
// Dims whose sizes differ between the two types are the distributed dims.
52 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
56 distributedType.getContext());
// NOTE(review): fragment of a helper that locates the unique dimension in
// which `distributedType` differs from `sequentialType`. It asserts equal
// ranks, scans dim sizes, asserts at most one dim differs, and returns that
// dim's index. Lines between the visible statements are elided in this view.
64 VectorType distributedType) {
65 assert(sequentialType.getRank() == distributedType.getRank() &&
66 "sequential and distributed vector types must have the same rank");
68 for (
int64_t i = 0; i < sequentialType.getRank(); ++i) {
69 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
// Enforce the single-distributed-dim invariant.
72 assert(distributedDim == -1 &&
"found multiple distributed dims");
76 return distributedDim;
// Helper that moves a value between its warp-wide ("sequential") form and
// its per-lane ("distributed") form through a memref buffer.
// NOTE(review): many interior lines of this struct are elided in this view;
// the comments below describe only the statements that are visible here.
86struct DistributedLoadStoreHelper {
// Preregisters the sequential/distributed value pair plus the lane id and a
// constant-zero index; caches the vector types when both values are vectors.
87 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
88 Value laneId, Value zero)
89 : sequentialVal(sequentialVal), distributedVal(distributedVal),
90 laneId(laneId), zero(zero) {
91 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
92 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
93 if (sequentialVectorType && distributedVectorType)
// Per-lane offset along dim `index`, built as an affine.apply of the lane
// id scaled by the distributed dim size. NOTE(review): the affine expr
// `tid` is defined on an elided line — confirm against the full file.
98 Value buildDistributedOffset(RewriterBase &
b, Location loc, int64_t index) {
99 int64_t distributedSize = distributedVectorType.getDimSize(index);
101 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
102 ArrayRef<Value>{laneId});
// Store `val` into `buffer`: scalars via memref.store at index zero,
// vectors via an all-in-bounds vector.transfer_write. When storing the
// distributed value, the indices named by distributionMap get per-lane
// offsets from buildDistributedOffset.
112 Operation *buildStore(RewriterBase &
b, Location loc, Value val,
114 assert((val == distributedVal || val == sequentialVal) &&
115 "Must store either the preregistered distributed or the "
116 "preregistered sequential value.");
118 if (!isa<VectorType>(val.
getType()))
119 return memref::StoreOp::create(
b, loc, val, buffer, zero);
123 int64_t rank = sequentialVectorType.getRank();
124 SmallVector<Value>
indices(rank, zero);
125 if (val == distributedVal) {
126 for (
auto dimExpr : distributionMap.getResults()) {
127 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
128 indices[index] = buildDistributedOffset(
b, loc, index);
131 SmallVector<bool> inBounds(
indices.size(),
true);
132 return vector::TransferWriteOp::create(
134 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
// Mirror of buildStore: scalars via memref.load, vectors via an
// all-in-bounds vector.transfer_read; reading the distributed type uses
// the same per-lane offsets.
157 Value buildLoad(RewriterBase &
b, Location loc, Type type, Value buffer) {
160 if (!isa<VectorType>(type))
161 return memref::LoadOp::create(
b, loc, buffer, zero);
166 assert((type == distributedVectorType || type == sequentialVectorType) &&
167 "Must store either the preregistered distributed or the "
168 "preregistered sequential type.");
169 SmallVector<Value>
indices(sequentialVectorType.getRank(), zero);
170 if (type == distributedVectorType) {
171 for (
auto dimExpr : distributionMap.getResults()) {
172 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
173 indices[index] = buildDistributedOffset(
b, loc, index);
176 SmallVector<bool> inBounds(
indices.size(),
true);
177 return vector::TransferReadOp::create(
178 b, loc, cast<VectorType>(type), buffer,
indices,
180 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
// Preregistered values/types shared by the build* helpers above.
183 Value sequentialVal, distributedVal, laneId, zero;
184 VectorType sequentialVectorType, distributedVectorType;
185 AffineMap distributionMap;
// NOTE(review): isolated fragment of an elided function (original line ~198);
// what `res` is and which create overload this resolves to cannot be
// determined from this view — confirm against the full file.
198 return rewriter.
create(res);
// NOTE(review): fragment of WarpOpToScfIfPattern (class header and several
// interior lines are elided in this view). Visible flow: guard the warp-op
// body with `scf.if (laneid == 0)`, spill distributed args/results through
// buffers obtained from options.warpAllocationFn, synchronize via
// options.warpSynchronizationFn, and finally replace the warp op.
232 WarpOpToScfIfPattern(MLIRContext *context,
233 const WarpExecuteOnLane0LoweringOptions &options,
234 PatternBenefit benefit = 1)
235 : WarpDistributionPattern(context, benefit), options(options) {}
237 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
238 PatternRewriter &rewriter)
const override {
239 assert(warpOp.getBodyRegion().hasOneBlock() &&
240 "expected WarpOp with single block");
241 Block *warpOpBody = &warpOp.getBodyRegion().front();
242 Location loc = warpOp.getLoc();
245 OpBuilder::InsertionGuard g(rewriter);
// Predicate: only lane 0 executes the sequential body.
250 Value isLane0 = arith::CmpIOp::create(
251 rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
252 auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
254 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
// For each warp-op arg: store the distributed value to a buffer, then load
// it back with the sequential type for use inside the if-body.
258 SmallVector<Value> bbArgReplacements;
259 for (
const auto &it : llvm::enumerate(warpOp.getArgs())) {
260 Value sequentialVal = warpOpBody->
getArgument(it.index());
261 Value distributedVal = it.value();
262 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
263 warpOp.getLaneid(), c0);
267 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
270 helper.buildStore(rewriter, loc, distributedVal, buffer);
273 bbArgReplacements.push_back(
274 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
// Only synchronize when something was actually spilled for the args.
278 if (!warpOp.getArgs().empty()) {
280 options.warpSynchronizationFn(loc, rewriter, warpOp);
284 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
// After the body: store each yielded sequential value and load it back as
// the distributed warp-op result type.
290 SmallVector<Value> replacements;
291 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
292 Location yieldLoc = yieldOp.getLoc();
293 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
294 Value sequentialVal = it.value();
295 Value distributedVal = warpOp->getResult(it.index());
296 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
297 warpOp.getLaneid(), c0);
301 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
307 helper.buildStore(rewriter, loc, sequentialVal, buffer);
318 replacements.push_back(
319 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
323 if (!yieldOp.getOperands().empty()) {
325 options.warpSynchronizationFn(loc, rewriter, warpOp);
331 scf::YieldOp::create(rewriter, yieldLoc);
334 rewriter.
replaceOp(warpOp, replacements);
340 const WarpExecuteOnLane0LoweringOptions &options;
// NOTE(review): fragment — computes the per-lane (distributed) VectorType by
// dividing mapped dims of `originalType`'s shape by the warp size. Visible
// logic: when a dim does not divide warpSize-style evenly in one direction,
// it must divide in the other; a dim smaller than the remaining warpSize is
// consumed (warpSize /= dim, dim -> 1), otherwise dim /= warpSize. Failure
// and loop-header lines are elided in this view.
353static VectorType getDistributedType(VectorType originalType,
AffineMap map,
361 if (targetShape[position] % warpSize != 0) {
362 if (warpSize % targetShape[position] != 0) {
365 warpSize /= targetShape[position];
366 targetShape[position] = 1;
369 targetShape[position] = targetShape[position] / warpSize;
376 VectorType targetType =
377 VectorType::get(targetShape, originalType.getElementType());
// NOTE(review): fragment — collects values defined outside `innerRegion`
// (but inside the warp op, per the isAncestor check) that are used inside
// it, together with their original types and their distributed types
// (computed via getDistributedType for vector-typed values; scalar types
// appear to pass through). Returns the three collections; the return type,
// loop headers, and some declarations are elided in this view.
387getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp,
Region &innerRegion,
389 llvm::SmallSetVector<Value, 32> escapingValues;
// Nothing can escape from an empty region.
392 if (innerRegion.
empty())
393 return {std::move(escapingValues), std::move(escapingValueTypes),
394 std::move(escapingValueDistTypes)};
397 if (warpOp->isAncestor(parent)) {
// SetVector::insert returns false on duplicates — record each value once.
398 if (!escapingValues.insert(operand->
get()))
401 if (
auto vecType = dyn_cast<VectorType>(distType)) {
403 distType = getDistributedType(vecType, map,
404 map.
isEmpty() ? 1 : warpOp.getWarpSize());
406 escapingValueTypes.push_back(operand->
get().
getType());
407 escapingValueDistTypes.push_back(distType);
410 return {std::move(escapingValues), std::move(escapingValueTypes),
411 std::move(escapingValueDistTypes)};
// NOTE(review): constructor fragment (class header and leading parameters
// are elided): stores the distribution-map callback and the budget of
// elements that may be handled by the extract fallback.
435 unsigned maxNumElementsToExtract, PatternBenefit
b = 1)
436 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)),
437 maxNumElementsToExtract(maxNumElementsToExtract) {}
// NOTE(review): fragment. Distributes a vector.transfer_write nested in a
// warp op: derives the distributed written-vector (and mask) types from
// distributionMapFn, clones the write outside the warp op via cloneWriteOp,
// delinearizes the lane id over the seq/dist shape ratios, and rebases the
// write indices with affine.apply(d0 + scale * d1). Guard/failure returns
// and several declarations are elided in this view.
441 LogicalResult tryDistributeOp(RewriterBase &rewriter,
442 vector::TransferWriteOp writeOp,
443 WarpExecuteOnLane0Op warpOp)
const {
444 VectorType writtenVectorType = writeOp.getVectorType();
// 0-D writes are handled elsewhere (by the extract path).
448 if (writtenVectorType.getRank() == 0)
452 AffineMap map = distributionMapFn(writeOp.getVector());
453 VectorType targetType =
454 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
// Masked writes only supported for minor-identity permutation maps.
460 if (writeOp.getMask()) {
467 if (!writeOp.getPermutationMap().isMinorIdentity())
470 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
475 vector::TransferWriteOp newWriteOp =
476 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
480 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
// Basis for lane-id delinearization: per-dim ratio seqSize / distSize.
486 SmallVector<OpFoldResult> delinearizedIdSizes;
487 for (
auto [seqSize, distSize] :
488 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
489 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
490 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
492 SmallVector<Value> delinearized;
494 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
495 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
// Fallback path (elided condition): replicate the lane id per rank.
501 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
504 AffineMap indexMap = map.
compose(newWriteOp.getPermutationMap());
505 Location loc = newWriteOp.getLoc();
506 SmallVector<Value>
indices(newWriteOp.getIndices().begin(),
507 newWriteOp.getIndices().end());
510 bindDims(newWarpOp.getContext(), d0, d1);
511 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
514 unsigned indexPos = indexExpr.getPosition();
515 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
516 Value laneId = delinearized[vectorPos];
520 rewriter, loc, d0 + scale * d1, {
indices[indexPos], laneId});
522 newWriteOp.getIndicesMutable().assign(
indices);
// NOTE(review): fragment. Fallback for small writes: bail out when the
// written vector has more elements than maxNumElementsToExtract, yield the
// written vector from a new warp op, then re-create the transfer_write
// inside a fresh result-less WarpExecuteOnLane0Op so only lane 0 executes
// it. Failure returns and some lines are elided in this view.
528 LogicalResult tryExtractOp(RewriterBase &rewriter,
529 vector::TransferWriteOp writeOp,
530 WarpExecuteOnLane0Op warpOp)
const {
531 Location loc = writeOp.getLoc();
532 VectorType vecType = writeOp.getVectorType();
534 if (vecType.getNumElements() > maxNumElementsToExtract) {
538 "writes more elements ({0}) than allowed to extract ({1})",
539 vecType.getNumElements(), maxNumElementsToExtract));
// Guard (elided effect): warp body containing only the write + yield.
543 if (llvm::all_of(warpOp.getOps(),
544 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
547 SmallVector<Value> yieldValues = {writeOp.getVector()};
548 SmallVector<Type> retTypes = {vecType};
549 SmallVector<size_t> newRetIndices;
551 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
// Fresh warp op with no results hosts the lane-0-only write.
555 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc,
TypeRange(),
556 newWarpOp.getLaneid(),
557 newWarpOp.getWarpSize());
558 Block &body = secondWarpOp.getBodyRegion().front();
561 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
562 newWriteOp.getValueToStoreMutable().assign(
563 newWarpOp.getResult(newRetIndices[0]));
565 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
// NOTE(review): fragment (signature start elided). Matches a
// vector.transfer_write that is the last op before the warp terminator,
// requires every operand other than the stored vector / mask to be defined
// outside the region, then tries tryDistributeOp and falls back to
// tryExtractOp (the latter only when there is no mask).
570 PatternRewriter &rewriter)
const override {
571 gpu::YieldOp yield = warpOp.getTerminator();
572 Operation *lastNode = yield->getPrevNode();
573 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
577 Value maybeMask = writeOp.getMask();
578 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
579 return writeOp.getVector() == value ||
580 (maybeMask && maybeMask == value) ||
581 warpOp.isDefinedOutsideOfRegion(value);
585 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
// Masked writes are not supported by the extract fallback.
589 if (writeOp.getMask())
592 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
// NOTE(review): fragment. Clones `writeOp` outside the warp op, routing the
// stored vector (and the mask, when maybeMaskType is set — branch partially
// elided) through new warp-op results of the distributed types.
602 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
603 WarpExecuteOnLane0Op warpOp,
604 vector::TransferWriteOp writeOp,
605 VectorType targetType,
606 VectorType maybeMaskType)
const {
607 assert(writeOp->getParentOp() == warpOp &&
608 "write must be nested immediately under warp");
609 OpBuilder::InsertionGuard g(rewriter);
610 SmallVector<size_t> newRetIndices;
611 WarpExecuteOnLane0Op newWarpOp;
// Masked variant: yield both vector and mask with distributed types.
614 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
615 TypeRange{targetType, maybeMaskType}, newRetIndices);
// Unmasked variant: yield only the stored vector.
618 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
623 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
625 newWriteOp.getValueToStoreMutable().assign(
626 newWarpOp.getResult(newRetIndices[0]));
628 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
// Budget for the tryExtractOp fallback; default 1 element.
633 unsigned maxNumElementsToExtract = 1;
// NOTE(review): fragment of an elementwise-distribution pattern. Visible
// flow: find a yielded elementwise-op result, yield each of the op's
// operands with a target type of the distributed shape + the operand's
// element type (scalar operands pass through unchanged), then rebuild the
// elementwise op outside the warp op on the distributed operands.
656 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
657 PatternRewriter &rewriter)
const override {
658 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
666 Value distributedVal = warpOp.getResult(operandIndex);
667 SmallVector<Value> yieldValues;
668 SmallVector<Type> retTypes;
669 Location loc = warpOp.getLoc();
// Vector results: keep distributed shape, operand element type.
672 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
674 auto operandType = cast<VectorType>(operand.
get().
getType());
676 VectorType::get(vecType.getShape(), operandType.getElementType());
679 assert(!isa<VectorType>(operandType) &&
680 "unexpected yield of vector from op with scalar result type");
681 targetType = operandType;
683 retTypes.push_back(targetType);
684 yieldValues.push_back(operand.
get());
686 SmallVector<size_t> newRetIndices;
687 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
688 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
690 SmallVector<Value> newOperands(elementWise->
getOperands().begin(),
692 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
693 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
695 OpBuilder::InsertionGuard g(rewriter);
698 rewriter, loc, elementWise, newOperands,
699 {newWarpOp.getResult(operandIndex).getType()});
// NOTE(review): fragment (signature start elided). A yielded splat constant
// is rebuilt outside the warp op: the same splat scalar attribute is
// re-wrapped in the distributed (warp-result) shaped type and materialized
// as a new arith.constant.
723 PatternRewriter &rewriter)
const override {
724 OpOperand *yieldOperand =
729 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
736 Attribute scalarAttr = dense.getSplatValue<Attribute>();
738 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
739 Location loc = warpOp.getLoc();
741 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
// NOTE(review): fragment. Distributes vector.step: requires the step result
// to have exactly warp-size elements, then builds each lane's distributed
// value by broadcasting the lane id into the distributed vector type.
770 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
771 PatternRewriter &rewriter)
const override {
772 OpOperand *yieldOperand =
773 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
779 if (resTy.getNumElements() !=
static_cast<int64_t
>(warpOp.getWarpSize()))
782 llvm::formatv(
"Expected result size ({0}) to be of warp size ({1})",
783 resTy.getNumElements(), warpOp.getWarpSize()));
784 VectorType newVecTy =
785 cast<VectorType>(warpOp.getResult(operandIdx).getType());
787 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
788 newVecTy, warpOp.getLaneid());
// NOTE(review): fragment (signature start elided). Distributes a single-use
// vector.transfer_read whose base is defined outside the warp region:
// indices, padding, and (when present) the mask are routed through new
// warp-op results, the lane id is delinearized over the distributed shape,
// and the read is re-created outside with each mapped index rebased by
// affine.apply(d0 + scale * d1). Several guard lines are elided.
815 PatternRewriter &rewriter)
const override {
819 OpOperand *operand =
getWarpResult(warpOp, [](Operation *op) {
821 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
825 warpOp,
"warp result is not a vector.transfer_read op");
829 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
831 read,
"source must be defined outside of the region");
834 Value distributedVal = warpOp.getResult(operandIndex);
836 SmallVector<Value, 4>
indices(read.getIndices().begin(),
837 read.getIndices().end());
838 auto sequentialType = cast<VectorType>(read.getResult().getType());
839 auto distributedType = cast<VectorType>(distributedVal.
getType());
841 AffineMap indexMap = map.
compose(read.getPermutationMap());
// Split the lane id into one component per distributed dim.
845 SmallVector<Value> delinearizedIds;
847 distributedType.getShape(), warpOp.getWarpSize(),
848 warpOp.getLaneid(), delinearizedIds)) {
850 read,
"cannot delinearize lane ID for distribution");
852 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
// Route indices + padding (+ mask) through the new warp op's results.
855 OpBuilder::InsertionGuard g(rewriter);
856 SmallVector<Value> additionalResults(
indices.begin(),
indices.end());
857 SmallVector<Type> additionalResultTypes(
indices.size(),
859 additionalResults.push_back(read.getPadding());
860 additionalResultTypes.push_back(read.getPadding().getType());
862 bool hasMask =
false;
863 if (read.getMask()) {
873 read,
"non-trivial permutation maps not supported");
874 VectorType maskType =
875 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
876 additionalResults.push_back(read.getMask());
877 additionalResultTypes.push_back(maskType);
880 SmallVector<size_t> newRetIndices;
882 rewriter, warpOp, additionalResults, additionalResultTypes,
884 distributedVal = newWarpOp.getResult(operandIndex);
887 SmallVector<Value> newIndices;
888 for (int64_t i = 0, e =
indices.size(); i < e; ++i)
889 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
// Rebase each mapped index: newIdx = oldIdx + distDimSize * laneComponent.
894 bindDims(read.getContext(), d0, d1);
895 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
898 unsigned indexPos = indexExpr.getPosition();
899 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
900 int64_t scale = distributedType.getDimSize(vectorPos);
902 rewriter, read.getLoc(), d0 + scale * d1,
903 {newIndices[indexPos], delinearizedIds[vectorPos]});
907 Value newPadding = newWarpOp.getResult(newRetIndices[
indices.size()]);
910 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
912 auto newRead = vector::TransferReadOp::create(
913 rewriter, read.getLoc(), distributedVal.
getType(), read.getBase(),
914 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
915 read.getInBoundsAttr());
// NOTE(review): fragment (signature start elided). Deduplicates yielded
// values and drops dead warp-op results: builds a compacted yield/result
// list keyed by yield operand, creates a new warp op when anything was
// removed, and remaps surviving results via dedupResultPositionMap. Some
// declarations and conditions are elided in this view.
927 PatternRewriter &rewriter)
const override {
928 SmallVector<Type> newResultTypes;
929 newResultTypes.reserve(warpOp->getNumResults());
930 SmallVector<Value> newYieldValues;
931 newYieldValues.reserve(warpOp->getNumResults());
934 gpu::YieldOp yield = warpOp.getTerminator();
945 for (OpResult
result : warpOp.getResults()) {
948 Value yieldOperand = yield.getOperand(
result.getResultNumber());
// insert() keeps the first position for a repeated yield operand.
949 auto it = dedupYieldOperandPositionMap.insert(
950 std::make_pair(yieldOperand, newResultTypes.size()));
951 dedupResultPositionMap.insert(std::make_pair(
result, it.first->second));
954 newResultTypes.push_back(
result.getType());
955 newYieldValues.push_back(yieldOperand);
// Nothing removed: leave the warp op untouched.
958 if (yield.getNumOperands() == newYieldValues.size())
962 rewriter, warpOp, newYieldValues, newResultTypes);
965 newWarpOp.getBody()->walk([&](Operation *op) {
971 SmallVector<Value> newValues;
972 newValues.reserve(warpOp->getNumResults());
973 for (OpResult
result : warpOp.getResults()) {
975 newValues.push_back(Value());
978 newWarpOp.getResult(dedupResultPositionMap.lookup(
result)));
// NOTE(review): fragment (signature start elided). Forwards a yielded value
// that needs no distribution: either a value defined outside the warp
// region, or a warp-op block argument — in which case the corresponding
// warp-op operand is forwarded instead.
990 PatternRewriter &rewriter)
const override {
991 gpu::YieldOp yield = warpOp.getTerminator();
993 unsigned resultIndex;
994 for (OpOperand &operand : yield->getOpOperands()) {
1003 valForwarded = operand.
get();
1007 auto arg = dyn_cast<BlockArgument>(operand.
get());
1008 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1010 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1013 valForwarded = warpOperand;
// NOTE(review): fragment (signature start elided). Distributes
// vector.broadcast: yields the broadcast source through a new warp-op
// result (with its own type), checks broadcastability to the distributed
// result type, and re-creates the broadcast outside the warp op.
1031 PatternRewriter &rewriter)
const override {
1032 OpOperand *operand =
1038 Location loc = broadcastOp.getLoc();
1040 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1041 Value broadcastSrc = broadcastOp.getSource();
1042 Type broadcastSrcType = broadcastSrc.
getType();
1049 vector::BroadcastableToResult::Success)
1051 SmallVector<size_t> newRetIndices;
1053 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1055 Value broadcasted = vector::BroadcastOp::create(
1056 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
// NOTE(review): fragment (signature start elided). Distributes
// vector.shape_cast: infers the distributed source type from the
// distributed result type via inferDistributedSrcType, yields the cast
// source with that type through a new warp op, and re-creates the
// shape_cast outside to the original distributed result type.
1068 PatternRewriter &rewriter)
const override {
1069 OpOperand *operand =
1077 auto castDistributedType =
1078 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1079 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1080 VectorType castResultType = castDistributedType;
1082 FailureOr<VectorType> maybeSrcType =
1083 inferDistributedSrcType(castDistributedType, castOriginalType);
1084 if (
failed(maybeSrcType))
1086 castDistributedType = *maybeSrcType;
1088 SmallVector<size_t> newRetIndices;
1090 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1093 Value newCast = vector::ShapeCastOp::create(
1094 rewriter, oldCastOp.getLoc(), castResultType,
1095 newWarpOp->getResult(newRetIndices[0]));
// NOTE(review): fragment — some branches are elided. Derives the
// distributed source type for a shape_cast from its distributed result
// type: equal ranks pass the distributed type through; a lower distributed
// rank is left-padded with unit dims; a higher distributed rank drops
// leading dims, failing (elided return) unless those dims are all 1. The
// branch producing a rank-1 type from getNumElements() sits in an elided
// context — confirm its condition in the full file.
1101 static FailureOr<VectorType>
1102 inferDistributedSrcType(VectorType distributedType, VectorType srcType) {
1103 unsigned distributedRank = distributedType.getRank();
1104 unsigned srcRank = srcType.getRank();
1105 if (distributedRank == srcRank)
1107 return distributedType;
1108 if (distributedRank < srcRank) {
1111 SmallVector<int64_t> shape(srcRank - distributedRank, 1);
1112 llvm::append_range(shape, distributedType.getShape());
1113 return VectorType::get(shape, distributedType.getElementType());
1125 return VectorType::get(distributedType.getNumElements(),
1126 srcType.getElementType());
1131 unsigned excessDims = distributedRank - srcRank;
1132 ArrayRef<int64_t> shape = distributedType.getShape();
1133 if (!llvm::all_of(shape.take_front(excessDims),
1134 [](int64_t d) { return d == 1; }))
1136 return VectorType::get(shape.drop_front(excessDims),
1137 distributedType.getElementType());
// NOTE(review): fragment. Templated over vector.create_mask /
// vector.constant_mask (enforced by the enable_if). Visible flow: requires
// the mask bounds to be defined outside the warp region, materializes
// constant_mask dim sizes as operands, delinearizes the lane id over the
// distributed shape, computes each per-lane bound with
// affine(s1 - s0 * distShape[i]) of (laneId_i, seqBound_i), and builds a
// vector.create_mask of the distributed type. Several lines are elided.
1161template <
typename OpType,
1162 typename = std::enable_if_t<llvm::is_one_of<
1163 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
1166 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1167 PatternRewriter &rewriter)
const override {
1168 OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
1177 !llvm::all_of(mask->
getOperands(), [&](Value value) {
1178 return warpOp.isDefinedOutsideOfRegion(value);
1182 Location loc = mask->
getLoc();
1185 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1187 ArrayRef<int64_t> seqShape = seqType.getShape();
1188 ArrayRef<int64_t> distShape = distType.getShape();
1189 SmallVector<Value> materializedOperands;
// create_mask already has SSA bounds; constant_mask's static sizes are
// materialized into values (constant creation elided).
1190 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
1191 materializedOperands.append(mask->
getOperands().begin(),
1194 auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
1195 auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
1196 for (
auto dimSize : dimSizes)
1197 materializedOperands.push_back(
1204 SmallVector<Value> delinearizedIds;
1205 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1206 warpOp.getWarpSize(), warpOp.getLaneid(),
1209 mask,
"cannot delinearize lane ID for distribution");
1210 assert(!delinearizedIds.empty());
1218 SmallVector<Value> newOperands;
1219 for (
int i = 0, e = distShape.size(); i < e; ++i) {
1225 Value maskDimIdx = affine::makeComposedAffineApply(
1226 rewriter, loc, s1 - s0 * distShape[i],
1227 {delinearizedIds[i], materializedOperands[i]});
1228 newOperands.push_back(maskDimIdx);
1232 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
// NOTE(review): fragment. Distributes vector.insert_strided_slice (>= 2-D
// dest): locates the distributed dest dim, requires it to map into the
// source's trailing dims and to be fully inserted, distributes the source
// shape along the matching dim, routes source and dest through new warp-op
// results, and re-creates the insert outside. Some lines are elided.
1267 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1268 PatternRewriter &rewriter)
const override {
1269 OpOperand *operand =
1270 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1276 auto distributedType =
1277 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1280 if (distributedType.getRank() < 2)
1282 insertOp,
"result vector type must be 2D or higher");
1285 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1286 int64_t destDistributedDim =
1288 assert(destDistributedDim != -1 &&
"could not find distributed dimension");
1290 VectorType srcType = insertOp.getSourceVectorType();
1291 VectorType destType = insertOp.getDestVectorType();
// Map the dest's distributed dim into the (lower-rank) source.
1296 int64_t sourceDistributedDim =
1297 destDistributedDim - (destType.getRank() - srcType.getRank());
1298 if (sourceDistributedDim < 0)
1301 "distributed dimension must be in the last k dims of dest vector");
1303 if (srcType.getDimSize(sourceDistributedDim) !=
1304 destType.getDimSize(destDistributedDim))
1306 insertOp,
"distributed dimension must be fully inserted");
1307 SmallVector<int64_t> newSourceDistShape(
1308 insertOp.getSourceVectorType().getShape());
1309 newSourceDistShape[sourceDistributedDim] =
1310 distributedType.getDimSize(destDistributedDim);
1312 VectorType::get(newSourceDistShape, distributedType.getElementType());
1313 VectorType newDestTy = distributedType;
1314 SmallVector<size_t> newRetIndices;
1315 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1316 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1317 {newSourceTy, newDestTy}, newRetIndices);
1319 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1320 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1323 Value newInsert = vector::InsertStridedSliceOp::create(
1324 rewriter, insertOp.getLoc(), distributedDest.
getType(),
1325 distributedSource, distributedDest, insertOp.getOffsets(),
1326 insertOp.getStrides());
// NOTE(review): fragment. Distributes vector.extract_strided_slice (>= 2-D
// result): finds the distributed dim; if it lies within the extracted dims,
// requires it to be extracted fully (offset 0, full size). The source is
// yielded with a distributed type, the distributed size is patched into the
// sizes attribute, and the extract is re-created outside the warp op. Some
// lines are elided in this view.
1356 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1357 PatternRewriter &rewriter)
const override {
1358 OpOperand *operand =
1359 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1365 auto distributedType =
1366 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1369 if (distributedType.getRank() < 2)
1371 extractOp,
"result vector type must be 2D or higher");
1374 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1376 assert(distributedDim != -1 &&
"could not find distributed dimension");
1378 int64_t numOfExtractedDims =
static_cast<int64_t
>(extractOp.getSizes().size());
1386 if (distributedDim < numOfExtractedDims) {
1387 int64_t distributedDimOffset =
1388 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1390 int64_t distributedDimSize =
1391 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1393 if (distributedDimOffset != 0 ||
1394 distributedDimSize != yieldedType.getDimSize(distributedDim))
1396 extractOp,
"distributed dimension must be fully extracted");
1398 SmallVector<int64_t> newDistributedShape(
1399 extractOp.getSourceVectorType().getShape());
1400 newDistributedShape[distributedDim] =
1401 distributedType.getDimSize(distributedDim);
1402 auto newDistributedType =
1403 VectorType::get(newDistributedShape, distributedType.getElementType());
1404 SmallVector<size_t> newRetIndices;
1405 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1406 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
// Patch the extracted size of the distributed dim down to per-lane size.
1409 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1410 extractOp.getSizes(), [](Attribute attr) { return attr; });
1412 if (distributedDim <
static_cast<int64_t
>(distributedSizes.size()))
1414 distributedType.getDimSize(distributedDim));
1418 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1419 Value newExtract = vector::ExtractStridedSliceOp::create(
1420 rewriter, extractOp.getLoc(), distributedType, distributedVec,
1421 extractOp.getOffsets(),
1422 ArrayAttr::get(rewriter.
getContext(), distributedSizes),
1423 extractOp.getStrides());
// NOTE(review): fragment. Distributes n-D vector.extract. Visible cases:
// rank <= 1 sources (or an undistributed result, per the type-equality
// check) re-yield the full source and rebuild a plain extract; otherwise
// the source's trailing dims are reshaped to the distributed sizes and the
// extract is re-created on the distributed vector. Some lines are elided.
1434 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1435 PatternRewriter &rewriter)
const override {
1436 OpOperand *operand =
1437 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1442 VectorType extractSrcType = extractOp.getSourceVectorType();
1443 Location loc = extractOp.getLoc();
1446 if (extractSrcType.getRank() <= 1) {
// Result not distributed: yield the whole source and extract outside.
1452 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1458 SmallVector<size_t> newRetIndices;
1459 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1460 rewriter, warpOp, {extractOp.getSource()},
1461 {extractOp.getSourceVectorType()}, newRetIndices);
1463 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1465 Value newExtract = vector::ExtractOp::create(
1466 rewriter, loc, distributedVec, extractOp.getMixedPosition());
1473 auto distributedType =
1474 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1475 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1477 assert(distributedDim != -1 &&
"could not find distributed dimension");
1478 (void)distributedDim;
// Trailing (non-extracted) source dims take the distributed sizes.
1481 SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1482 for (
int i = 0; i < distributedType.getRank(); ++i)
1483 newDistributedShape[i + extractOp.getNumIndices()] =
1484 distributedType.getDimSize(i);
1485 auto newDistributedType =
1486 VectorType::get(newDistributedShape, distributedType.getElementType());
1487 SmallVector<size_t> newRetIndices;
1488 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1489 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1492 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1494 Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1495 extractOp.getMixedPosition());
// NOTE(review): fragment. Distributes a scalar vector.extract from a 0-D or
// 1-D source with f32/i32 elements: the source is distributed across lanes
// (elementsPerLane each), the broadcast-source lane is computed with
// sym0.ceilDiv(elementsPerLane) of the position, the lane-local element is
// taken at pos % elementsPerLane, and the value is shared across the warp
// through the user-supplied warpShuffleFromIdxFn. Some lines are elided.
1505 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1506 PatternBenefit
b = 1)
1507 : WarpDistributionPattern(ctx,
b), warpShuffleFromIdxFn(std::move(fn)) {}
1508 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1509 PatternRewriter &rewriter)
const override {
1510 OpOperand *operand =
1511 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1516 VectorType extractSrcType = extractOp.getSourceVectorType();
1518 if (extractSrcType.getRank() > 1) {
1520 extractOp,
"only 0-D or 1-D source supported for now");
// Element-type restriction presumably comes from the shuffle intrinsic's
// 32-bit payload — confirm against warpShuffleFromIdxFn's contract.
1524 if (!extractSrcType.getElementType().isF32() &&
1525 !extractSrcType.getElementType().isInteger(32))
1527 extractOp,
"only f32/i32 element types are supported");
1528 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1529 Type elType = extractSrcType.getElementType();
1530 VectorType distributedVecType;
1531 if (!is0dOrVec1Extract) {
1532 assert(extractSrcType.getRank() == 1 &&
1533 "expected that extract src rank is 0 or 1");
1534 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1536 int64_t elementsPerLane =
1537 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1538 distributedVecType = VectorType::get({elementsPerLane}, elType);
1540 distributedVecType = extractSrcType;
// Yield the source plus any dynamic position values from the warp op.
1543 SmallVector<Value> additionalResults{extractOp.getSource()};
1544 SmallVector<Type> additionalResultTypes{distributedVecType};
1545 additionalResults.append(
1546 SmallVector<Value>(extractOp.getDynamicPosition()));
1547 additionalResultTypes.append(
1548 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1550 Location loc = extractOp.getLoc();
1551 SmallVector<size_t> newRetIndices;
1552 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1553 rewriter, warpOp, additionalResults, additionalResultTypes,
1556 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Single-element source: every lane holds the value; extract locally.
1560 if (is0dOrVec1Extract) {
1562 SmallVector<int64_t>
indices(extractSrcType.getRank(), 0);
1564 vector::ExtractOp::create(rewriter, loc, distributedVec,
indices);
1570 int64_t staticPos = extractOp.getStaticPosition()[0];
1571 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1572 ? (newWarpOp->getResult(newRetIndices[1]))
1576 int64_t elementsPerLane = distributedVecType.getShape()[0];
1579 Value broadcastFromTid = affine::makeComposedAffineApply(
1580 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1583 elementsPerLane == 1
1585 : affine::makeComposedAffineApply(rewriter, loc,
1586 sym0 % elementsPerLane, pos);
1588 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1591 Value shuffled = warpShuffleFromIdxFn(
1592 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1598 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
// NOTE(review): fragment. Distributes a scalar vector.insert into a 0-D or
// 1-D dest: dest, stored value, and any dynamic position are routed through
// new warp-op results. If the dest is not distributed, a plain insert is
// rebuilt; otherwise only the lane selected by sym0.ceilDiv(elementsPerLane)
// performs the insert at pos % elementsPerLane inside an scf.if, and the
// other lanes yield their distributed vector unchanged. Lines are elided.
1605 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1606 PatternRewriter &rewriter)
const override {
1607 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1612 VectorType vecType = insertOp.getDestVectorType();
1613 VectorType distrType =
1614 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1617 if (vecType.getRank() > 1) {
1619 insertOp,
"only 0-D or 1-D source supported for now");
1623 SmallVector<Value> additionalResults{insertOp.getDest(),
1624 insertOp.getValueToStore()};
1625 SmallVector<Type> additionalResultTypes{
1626 distrType, insertOp.getValueToStore().getType()};
1627 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1628 additionalResultTypes.append(
1629 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1631 Location loc = insertOp.getLoc();
1632 SmallVector<size_t> newRetIndices;
1633 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1634 rewriter, warpOp, additionalResults, additionalResultTypes,
1637 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1638 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1642 if (vecType.getRank() != 0) {
1643 int64_t staticPos = insertOp.getStaticPosition()[0];
1644 pos = ShapedType::isDynamic(staticPos)
1645 ? (newWarpOp->getResult(newRetIndices[2]))
// Undistributed dest: every lane performs the same insert.
1650 if (vecType == distrType) {
1652 SmallVector<OpFoldResult>
indices;
1656 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1665 int64_t elementsPerLane = distrType.getShape()[0];
1668 Value insertingLane = affine::makeComposedAffineApply(
1669 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1671 OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1672 rewriter, loc, sym0 % elementsPerLane, pos);
1673 Value isInsertingLane =
1674 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1675 newWarpOp.getLaneid(), insertingLane);
// Only the owning lane mutates its slice; others pass theirs through.
1678 rewriter, loc, isInsertingLane,
1680 [&](OpBuilder &builder, Location loc) {
1681 Value newInsert = vector::InsertOp::create(
1682 builder, loc, newSource, distributedVec, newPos);
1683 scf::YieldOp::create(builder, loc, newInsert);
1686 [&](OpBuilder &builder, Location loc) {
1687 scf::YieldOp::create(builder, loc, distributedVec);
// NOTE(review): fragment — the definition continues past the end of this
// view. Distributes n-D vector.insert: rank <= 1 dests (or an
// undistributed result, per the type-equality check) re-yield source and
// dest and rebuild a plain insert; otherwise the unique distributed dest
// dim is found, the source shape is distributed along the matching dim when
// the dim maps into the source, and both values are routed through new
// warp-op results of the distributed types.
1697 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1698 PatternRewriter &rewriter)
const override {
1699 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1704 Location loc = insertOp.getLoc();
1707 if (insertOp.getDestVectorType().getRank() <= 1) {
1713 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1716 SmallVector<size_t> newRetIndices;
1717 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1718 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1719 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1722 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1723 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1724 Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1726 insertOp.getMixedPosition());
1733 auto distrDestType =
1734 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1735 auto yieldedType = cast<VectorType>(operand->
get().
getType());
// Locate the single dim whose size differs (the distributed dim).
1736 int64_t distrDestDim = -1;
1737 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1738 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1741 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1745 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1748 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1749 SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
// distrSrcDim < 0 means the distributed dim is among the inserted-into
// indices, so the source keeps its shape.
1756 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1757 if (distrSrcDim >= 0)
1758 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1760 VectorType::get(distrSrcShape, distrDestType.getElementType());
1763 SmallVector<size_t> newRetIndices;
1764 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1765 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1766 {distrSrcType, distrDestType}, newRetIndices);
1768 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1769 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1773 if (distrSrcDim >= 0) {
1775 newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1777 insertOp.getMixedPosition());
1780 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1781 SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1785 rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1786 Value isInsertingLane =
1787 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1788 newWarpOp.getLaneid(), insertingLane);
1790 newPos[distrDestDim] %= elementsPerLane;
1791 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1792 Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1793 distributedDest, newPos);
1794 scf::YieldOp::create(builder, loc, newInsert);
1796 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1797 scf::YieldOp::create(builder, loc, distributedDest);
1799 newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1801 nonInsertingBuilder)
1838 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
1840 PatternRewriter &rewriter)
const override {
1841 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1843 Operation *lastNode = warpOpYield->getPrevNode();
1844 auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1855 SmallVector<Value> nonIfYieldValues;
1856 SmallVector<unsigned> nonIfYieldIndices;
1857 llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1858 llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1859 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1862 nonIfYieldValues.push_back(yieldOperand.
get());
1863 nonIfYieldIndices.push_back(yieldOperandIdx);
1866 OpResult ifResult = cast<OpResult>(yieldOperand.
get());
1868 ifResultMapping[yieldOperandIdx] = ifResultIdx;
1871 if (!isa<VectorType>(ifResult.
getType()))
1873 VectorType distType =
1874 cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1875 ifResultDistTypes[ifResultIdx] = distType;
1880 auto [escapingValuesThen, escapingValueInputTypesThen,
1881 escapingValueDistTypesThen] =
1882 getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1884 auto [escapingValuesElse, escapingValueInputTypesElse,
1885 escapingValueDistTypesElse] =
1886 getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1888 if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1889 llvm::is_contained(escapingValueDistTypesElse, Type{}))
1897 SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1898 newWarpOpYieldValues.append(escapingValuesThen.begin(),
1899 escapingValuesThen.end());
1900 newWarpOpYieldValues.append(escapingValuesElse.begin(),
1901 escapingValuesElse.end());
1902 SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1903 newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1904 escapingValueDistTypesThen.end());
1905 newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1906 escapingValueDistTypesElse.end());
1908 for (
auto [idx, val] :
1909 llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1910 newWarpOpYieldValues.push_back(val);
1911 newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1915 SmallVector<size_t> newIndices;
1917 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
1919 SmallVector<Type> newIfOpDistResTypes;
1920 for (
auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1921 Type distType = cast<Value>(res).getType();
1922 if (
auto vecType = dyn_cast<VectorType>(distType)) {
1923 AffineMap map = distributionMapFn(cast<Value>(res));
1925 distType = ifResultDistTypes.count(i)
1926 ? ifResultDistTypes[i]
1927 : getDistributedType(
1929 map.
isEmpty() ? 1 : newWarpOp.getWarpSize());
1931 newIfOpDistResTypes.push_back(distType);
1934 OpBuilder::InsertionGuard g(rewriter);
1936 auto newIfOp = scf::IfOp::create(
1937 rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1938 newWarpOp.getResult(newIndices[0]),
static_cast<bool>(ifOp.thenBlock()),
1939 static_cast<bool>(ifOp.elseBlock()));
1940 auto encloseRegionInWarpOp =
1942 llvm::SmallSetVector<Value, 32> &escapingValues,
1943 SmallVector<Type> &escapingValueInputTypes,
1944 size_t warpResRangeStart) {
1945 OpBuilder::InsertionGuard g(rewriter);
1949 llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1950 SmallVector<Value> innerWarpInputVals;
1951 SmallVector<Type> innerWarpInputTypes;
1952 for (
size_t i = 0; i < escapingValues.size();
1953 ++i, ++warpResRangeStart) {
1954 innerWarpInputVals.push_back(
1955 newWarpOp.getResult(newIndices[warpResRangeStart]));
1956 escapeValToBlockArgIndex[escapingValues[i]] =
1957 innerWarpInputTypes.size();
1958 innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1960 auto innerWarp = WarpExecuteOnLane0Op::create(
1961 rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1962 newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1963 innerWarpInputVals, innerWarpInputTypes);
1965 innerWarp.getWarpRegion().takeBody(*oldIfBranch->
getParent());
1966 innerWarp.getWarpRegion().addArguments(
1967 innerWarpInputTypes,
1968 SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1970 SmallVector<Value> yieldOperands;
1972 yieldOperands.push_back(operand);
1976 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1978 scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1982 innerWarp.walk([&](Operation *op) {
1983 SmallVector<std::pair<unsigned, Value>> replacements;
1985 auto it = escapeValToBlockArgIndex.find(operand.
get());
1986 if (it == escapeValToBlockArgIndex.end())
1988 replacements.emplace_back(
1990 innerWarp.getBodyRegion().getArgument(it->second));
1992 if (!replacements.empty()) {
1994 for (
auto [idx, newVal] : replacements)
1999 mlir::vector::moveScalarUniformCode(innerWarp);
2001 encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
2002 &newIfOp.getThenRegion().front(), escapingValuesThen,
2003 escapingValueInputTypesThen, 1);
2004 if (!ifOp.getElseRegion().empty())
2005 encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
2006 &newIfOp.getElseRegion().front(),
2007 escapingValuesElse, escapingValueInputTypesElse,
2008 1 + escapingValuesThen.size());
2011 for (
auto [origIdx, newIdx] : ifResultMapping)
2013 newIfOp.getResult(newIdx), newIfOp);
2021 OpBuilder::InsertionGuard guard(rewriter);
2023 Operation *yield = newWarpOp.getTerminator();
2025 for (
auto [origIdx, ifResultIdx] : ifResultMapping) {
2026 Value poison = ub::PoisonOp::create(
2027 rewriter, ifOp.getLoc(), ifOp.getResult(ifResultIdx).getType());
2076 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
2077 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2078 PatternRewriter &rewriter)
const override {
2079 gpu::YieldOp warpOpYield = warpOp.getTerminator();
2081 Operation *lastNode = warpOpYield->getPrevNode();
2082 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
2087 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2088 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2090 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2101 SmallVector<Value> nonForYieldedValues;
2102 SmallVector<unsigned> nonForResultIndices;
2103 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2104 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2105 llvm::SmallBitVector forResultsMapped(forOp.getNumResults());
2106 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2109 nonForYieldedValues.push_back(yieldOperand.
get());
2113 OpResult forResult = cast<OpResult>(yieldOperand.
get());
2116 forResultsMapped.set(forResultNumber);
2119 if (!isa<VectorType>(forResult.
getType()))
2121 VectorType distType = cast<VectorType>(
2123 forResultDistTypes[forResultNumber] = distType;
2131 SmallVector<Value> newWarpOpYieldValues;
2132 SmallVector<Type> newWarpOpDistTypes;
2133 newWarpOpYieldValues.insert(
2134 newWarpOpYieldValues.end(),
2135 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2136 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2137 {forOp.getLowerBound().getType(),
2138 forOp.getUpperBound().getType(),
2139 forOp.getStep().getType()});
2140 for (
auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2141 newWarpOpYieldValues.push_back(initArg);
2143 Type distType = initArg.getType();
2144 if (
auto vecType = dyn_cast<VectorType>(distType)) {
2148 AffineMap map = distributionMapFn(initArg);
2150 forResultDistTypes.count(i)
2151 ? forResultDistTypes[i]
2152 : getDistributedType(vecType, map,
2153 map.
isEmpty() ? 1 : warpOp.getWarpSize());
2155 newWarpOpDistTypes.push_back(distType);
2158 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2159 escapingValues.begin(), escapingValues.end());
2160 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2161 escapingValueDistTypes.begin(),
2162 escapingValueDistTypes.end());
2166 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2167 newWarpOpYieldValues.push_back(v);
2168 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2171 SmallVector<size_t> newIndices;
2172 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2173 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2177 const unsigned initArgsStartIdx = 3;
2178 const unsigned escapingValuesStartIdx =
2180 forOp.getInitArgs().size();
2182 SmallVector<Value> newForOpOperands;
2183 for (
size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2184 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2187 OpBuilder::InsertionGuard g(rewriter);
2189 auto newForOp = scf::ForOp::create(
2190 rewriter, forOp.getLoc(),
2191 newWarpOp.getResult(newIndices[0]),
2192 newWarpOp.getResult(newIndices[1]),
2193 newWarpOp.getResult(newIndices[2]), newForOpOperands,
2194 nullptr, forOp.getUnsignedCmp());
2200 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2201 newForOp.getRegionIterArgs().end());
2202 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2203 forOp.getResultTypes().end());
2207 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2208 for (
size_t i = escapingValuesStartIdx;
2209 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2210 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2211 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2212 innerWarpInputType.size();
2213 innerWarpInputType.push_back(
2214 escapingValueInputTypes[i - escapingValuesStartIdx]);
2217 auto innerWarp = WarpExecuteOnLane0Op::create(
2218 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2219 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2220 innerWarpInputType);
2223 SmallVector<Value> argMapping;
2224 argMapping.push_back(newForOp.getInductionVar());
2225 for (Value args : innerWarp.getBody()->getArguments())
2226 argMapping.push_back(args);
2228 argMapping.resize(forOp.getBody()->getNumArguments());
2229 SmallVector<Value> yieldOperands;
2230 for (Value operand : forOp.getBody()->getTerminator()->getOperands()) {
2231 if (BlockArgument blockArg = dyn_cast<BlockArgument>(operand);
2232 blockArg && blockArg.getOwner() == forOp.getBody()) {
2233 yieldOperands.push_back(argMapping[blockArg.getArgNumber()]);
2236 yieldOperands.push_back(operand);
2239 rewriter.
eraseOp(forOp.getBody()->getTerminator());
2240 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2245 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2249 if (!innerWarp.getResults().empty())
2250 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2254 for (
auto [origIdx, newIdx] : forResultMapping)
2256 newForOp.getResult(newIdx), newForOp);
2261 for (OpResult
result : forOp.getResults()) {
2262 if (forResultsMapped.test(
result.getResultNumber()))
2264 result, forOp.getInitArgs()[
result.getResultNumber()]);
2270 newForOp.walk([&](Operation *op) {
2271 SmallVector<std::pair<unsigned, Value>> replacements;
2273 auto it = argIndexMapping.find(operand.
get());
2274 if (it == argIndexMapping.end())
2276 replacements.emplace_back(
2278 innerWarp.getBodyRegion().getArgument(it->second));
2280 if (!replacements.empty()) {
2282 for (
auto [idx, newVal] : replacements)
2289 mlir::vector::moveScalarUniformCode(innerWarp);
2317 WarpOpReduction(MLIRContext *context,
2318 DistributedReductionFn distributedReductionFn,
2319 PatternBenefit benefit = 1)
2320 : WarpDistributionPattern(context, benefit),
2321 distributedReductionFn(std::move(distributedReductionFn)) {}
2323 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2324 PatternRewriter &rewriter)
const override {
2325 OpOperand *yieldOperand =
2326 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2332 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2334 if (vectorType.getRank() != 1)
2336 warpOp,
"Only rank 1 reductions can be distributed.");
2338 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2340 warpOp,
"Reduction vector dimension must match was size.");
2341 if (!reductionOp.getType().isIntOrFloat())
2343 warpOp,
"Reduction distribution currently only supports floats and "
2346 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2349 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2350 SmallVector<Type> retTypes = {
2351 VectorType::get({numElements}, reductionOp.getType())};
2352 if (reductionOp.getAcc()) {
2353 yieldValues.push_back(reductionOp.getAcc());
2354 retTypes.push_back(reductionOp.getAcc().getType());
2356 SmallVector<size_t> newRetIndices;
2357 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2358 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2362 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2365 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2366 reductionOp.getKind(), newWarpOp.getWarpSize());
2367 if (reductionOp.getAcc()) {
2369 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2370 newWarpOp.getResult(newRetIndices[1]));
2377 DistributedReductionFn distributedReductionFn;
2388void mlir::vector::populateDistributeTransferWriteOpPatterns(
2391 patterns.
add<WarpOpTransferWrite>(patterns.
getContext(), distributionMapFn,
2392 maxNumElementsToExtract, benefit);
2395void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2397 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
2399 patterns.
add<WarpOpTransferRead>(patterns.
getContext(), readBenefit);
2400 patterns.
add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2401 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
2402 WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
2403 WarpOpCreateMask<vector::CreateMaskOp>,
2404 WarpOpCreateMask<vector::ConstantMaskOp>,
2405 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2407 patterns.
add<WarpOpExtractScalar>(patterns.
getContext(), warpShuffleFromIdxFn,
2409 patterns.
add<WarpOpScfForOp>(patterns.
getContext(), distributionMapFn,
2411 patterns.
add<WarpOpScfIfOp>(patterns.
getContext(), distributionMapFn,
2415void mlir::vector::populateDistributeReduction(
2417 const DistributedReductionFn &distributedReductionFn,
2419 patterns.
add<WarpOpReduction>(patterns.
getContext(), distributedReductionFn,
2426 return llvm::all_of(op->
getOperands(), definedOutside) &&
2430void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2431 Block *body = warpOp.getBody();
2434 llvm::SmallSetVector<Operation *, 8> opsToMove;
2437 auto isDefinedOutsideOfBody = [&](
Value value) {
2439 return (definingOp && opsToMove.count(definingOp)) ||
2440 warpOp.isDefinedOutsideOfRegion(value);
2447 return isa<VectorType>(result.getType());
2449 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
2450 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead.