22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include "llvm/Support/FormatVariadic.h"
// NOTE(review): fragment — the function's opening line and several interior
// lines are missing from this view. Visible logic walks every rank of the
// sequential type and compares per-dimension sizes against the distributed
// type; presumably it collects the differing dims into a result built in
// distributedType's context — TODO confirm against the full source.
 43 VectorType distributedType) {
 49 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
// A dim whose size differs between the two types is a distributed dim.
 50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
 54 distributedType.getContext());
// NOTE(review): fragment — signature start and some interior lines missing.
// Finds the unique dimension along which `distributedType` differs from
// `sequentialType` (i.e. the dimension that was distributed across lanes)
// and returns its index; the asserts require equal ranks and at most one
// differing dim. Presumably returns -1 when no dim differs — TODO confirm.
 62 VectorType distributedType) {
 63 assert(sequentialType.getRank() == distributedType.getRank() &&
 64 "sequential and distributed vector types must have the same rank");
 66 for (
int64_t i = 0; i < sequentialType.getRank(); ++i) {
 67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
// Guard: only a single distributed dimension is supported.
 70 assert(distributedDim == -1 &&
"found multiple distributed dims");
 74 return distributedDim;
// Helper that materializes transfers between a warp-sequential value and its
// per-lane distributed counterpart through a memref buffer. It pairs a
// (sequentialVal, distributedVal) couple with the warp's laneId and a
// preconstructed zero index, and builds the store/load ops that round-trip
// a value through `buffer`.
// NOTE(review): fragment — interior lines are missing throughout this struct;
// comments below describe only the visible code.
 84struct DistributedLoadStoreHelper {
 85 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
 86 Value laneId, Value zero)
 87 : sequentialVal(sequentialVal), distributedVal(distributedVal),
 88 laneId(laneId), zero(zero) {
// Cache vector types when both values are vectors; dyn_cast leaves the
// members null for scalar values, which the build* methods special-case.
 89 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
 90 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
 91 if (sequentialVectorType && distributedVectorType)
// Computes the per-lane offset along dimension `index`:
// laneId * distributedSize, folded through an affine.apply.
 96 Value buildDistributedOffset(RewriterBase &
b, Location loc, int64_t index) {
 97 int64_t distributedSize = distributedVectorType.getDimSize(index);
 99 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
 100 ArrayRef<Value>{laneId});
// Stores `val` (which must be one of the two preregistered values) into
// `buffer`. Scalars use memref.store at offset zero; vectors use an
// in-bounds transfer_write with lane-dependent indices for the
// distributed value.
 110 Operation *buildStore(RewriterBase &
b, Location loc, Value val,
 112 assert((val == distributedVal || val == sequentialVal) &&
 113 "Must store either the preregistered distributed or the "
 114 "preregistered sequential value.");
 116 if (!isa<VectorType>(val.
getType()))
 117 return memref::StoreOp::create(
b, loc, val, buffer, zero);
// Start from all-zero indices, then offset each distributed dim
// (as given by distributionMap) by the lane's share.
 121 int64_t rank = sequentialVectorType.getRank();
 122 SmallVector<Value>
indices(rank, zero);
 123 if (val == distributedVal) {
 124 for (
auto dimExpr : distributionMap.getResults()) {
 125 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
 126 indices[index] = buildDistributedOffset(
b, loc, index);
// All accesses are in-bounds by construction of the buffer.
 129 SmallVector<bool> inBounds(
indices.size(),
true);
 130 return vector::TransferWriteOp::create(
 132 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
// Mirror of buildStore: loads `type` back out of `buffer`. Scalars use
// memref.load at offset zero; vectors use an in-bounds transfer_read,
// again offsetting distributed dims by the lane's share.
 155 Value buildLoad(RewriterBase &
b, Location loc, Type type, Value buffer) {
 158 if (!isa<VectorType>(type))
 159 return memref::LoadOp::create(
b, loc, buffer, zero);
// NOTE(review): message says "store" but this is the load path — the
// assert text appears copied from buildStore (cosmetic only; cannot be
// changed in a doc-only edit).
 164 assert((type == distributedVectorType || type == sequentialVectorType) &&
 165 "Must store either the preregistered distributed or the "
 166 "preregistered sequential type.");
 167 SmallVector<Value>
indices(sequentialVectorType.getRank(), zero);
 168 if (type == distributedVectorType) {
 169 for (
auto dimExpr : distributionMap.getResults()) {
 170 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
 171 indices[index] = buildDistributedOffset(
b, loc, index);
 174 SmallVector<bool> inBounds(
indices.size(),
true);
 175 return vector::TransferReadOp::create(
 176 b, loc, cast<VectorType>(type), buffer,
indices,
 178 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
// Preregistered values and their cached types; distributionMap maps
// distributed dims of the sequential type to lane offsets.
 181 Value sequentialVal, distributedVal, laneId, zero;
 182 VectorType sequentialVectorType, distributedVectorType;
 183 AffineMap distributionMap;
196 return rewriter.
create(res);
// Lowers gpu.warp_execute_on_lane_0 to scf.if(laneId == 0): the warp body
// runs only on lane 0, and values crossing the region boundary are
// round-tripped through buffers allocated by options.warpAllocationFn, with
// synchronization via options.warpSyncronizationFn.
// NOTE(review): fragment — interior lines are missing; comments describe
// visible code only.
 230 WarpOpToScfIfPattern(MLIRContext *context,
 231 const WarpExecuteOnLane0LoweringOptions &options,
 232 PatternBenefit benefit = 1)
 233 : WarpDistributionPattern(context, benefit), options(options) {}
 235 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 236 PatternRewriter &rewriter)
const override {
 237 assert(warpOp.getBodyRegion().hasOneBlock() &&
 238 "expected WarpOp with single block");
 239 Block *warpOpBody = &warpOp.getBodyRegion().front();
 240 Location loc = warpOp.getLoc();
 243 OpBuilder::InsertionGuard g(rewriter);
// Predicate the body on being lane 0.
 248 Value isLane0 = arith::CmpIOp::create(
 249 rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
 250 auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
 252 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
// For each warp-op argument: store the distributed value to a buffer
// (all lanes), then load it back in sequential form for use inside the
// lane-0 body.
 256 SmallVector<Value> bbArgReplacements;
 257 for (
const auto &it : llvm::enumerate(warpOp.getArgs())) {
 258 Value sequentialVal = warpOpBody->
getArgument(it.index());
 259 Value distributedVal = it.value();
 260 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
 261 warpOp.getLaneid(), c0);
 265 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
 268 helper.buildStore(rewriter, loc, distributedVal, buffer);
 271 bbArgReplacements.push_back(
 272 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
// Synchronize after the stores so lane 0 observes every lane's data.
 276 if (!warpOp.getArgs().empty()) {
 278 options.warpSyncronizationFn(loc, rewriter, warpOp);
// Move the warp body into the then-block, substituting the reloaded
// sequential values for the block arguments.
 282 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
// For each yielded value: lane 0 stores the sequential value; all lanes
// reload it in distributed form as the warp op's replacement results.
 288 SmallVector<Value> replacements;
 289 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
 290 Location yieldLoc = yieldOp.getLoc();
 291 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
 292 Value sequentialVal = it.value();
 293 Value distributedVal = warpOp->getResult(it.index());
 294 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
 295 warpOp.getLaneid(), c0);
 299 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
 305 helper.buildStore(rewriter, loc, sequentialVal, buffer);
 316 replacements.push_back(
 317 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
// Synchronize so every lane sees lane 0's stores before loading.
 321 if (!yieldOp.getOperands().empty()) {
 323 options.warpSyncronizationFn(loc, rewriter, warpOp);
 329 scf::YieldOp::create(rewriter, yieldLoc);
 332 rewriter.
replaceOp(warpOp, replacements);
 338 const WarpExecuteOnLane0LoweringOptions &options;
// Computes the per-lane (distributed) vector type for `originalType` when
// split across `warpSize` lanes along the dimension(s) selected by `map`.
// NOTE(review): fragment — several interior lines missing. Visible logic:
// when a target dim does not divide evenly by warpSize, it instead requires
// warpSize to divide by the dim, collapses that dim to 1, and carries the
// remaining factor forward; otherwise the dim is divided by warpSize.
351static VectorType getDistributedType(VectorType originalType,
AffineMap map,
 359 if (targetShape[position] % warpSize != 0) {
 360 if (warpSize % targetShape[position] != 0) {
 363 warpSize /= targetShape[position];
 364 targetShape[position] = 1;
 367 targetShape[position] = targetShape[position] / warpSize;
 374 VectorType targetType =
 375 VectorType::get(targetShape, originalType.getElementType());
// Collects values defined inside `warpOp` that escape into `innerRegion`,
// returning (escapingValues, their original types, their distributed types).
// Vector-typed escapees get a distributed type via getDistributedType; an
// empty inner region yields empty results.
// NOTE(review): fragment — interior lines missing; comments describe
// visible code only.
 385getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp,
Region &innerRegion,
 387 llvm::SmallSetVector<Value, 32> escapingValues;
 390 if (innerRegion.
empty())
 391 return {std::move(escapingValues), std::move(escapingValueTypes),
 392 std::move(escapingValueDistTypes)};
// A value escapes if its defining region is inside the warp op; the set
// de-duplicates repeated uses.
 395 if (warpOp->isAncestor(parent)) {
 396 if (!escapingValues.insert(operand->
get()))
 399 if (
auto vecType = dyn_cast<VectorType>(distType)) {
 401 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
 403 escapingValueTypes.push_back(operand->
get().
getType());
 404 escapingValueDistTypes.push_back(distType);
 407 return {std::move(escapingValues), std::move(escapingValueTypes),
 408 std::move(escapingValueDistTypes)};
// Pattern distributing a vector.transfer_write that terminates a warp op.
// Two strategies: tryDistributeOp splits the write across lanes by the
// distribution map; tryExtractOp falls back to extracting a small vector
// (bounded by maxNumElementsToExtract) and writing it from a new warp op.
// NOTE(review): fragment — interior lines are missing throughout; comments
// describe visible code only.
 432 unsigned maxNumElementsToExtract, PatternBenefit
b = 1)
 433 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)),
 434 maxNumElementsToExtract(maxNumElementsToExtract) {}
// Rewrites the write so each lane writes its distributed slice at a
// lane-dependent offset.
 438 LogicalResult tryDistributeOp(RewriterBase &rewriter,
 439 vector::TransferWriteOp writeOp,
 440 WarpExecuteOnLane0Op warpOp)
const {
 441 VectorType writtenVectorType = writeOp.getVectorType();
// 0-d writes cannot be distributed along a dimension.
 445 if (writtenVectorType.getRank() == 0)
 449 AffineMap map = distributionMapFn(writeOp.getVector());
 450 VectorType targetType =
 451 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
// Masked writes are only handled for minor-identity permutation maps;
// the mask is distributed with the same map.
 457 if (writeOp.getMask()) {
 464 if (!writeOp.getPermutationMap().isMinorIdentity())
 467 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
 472 vector::TransferWriteOp newWriteOp =
 473 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
 477 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
// Delinearize the lane id over seqSize/distSize factors so each
// distributed dim gets its own id component.
 483 SmallVector<OpFoldResult> delinearizedIdSizes;
 484 for (
auto [seqSize, distSize] :
 485 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
 486 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
 487 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
 489 SmallVector<Value> delinearized;
 491 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
 492 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
// Fallback: replicate the lane id per-rank when not delinearizing.
 498 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
// Compose distribution and permutation maps, then bump each affected
// write index by laneIdComponent * distributedSize.
 501 AffineMap indexMap = map.
compose(newWriteOp.getPermutationMap());
 502 Location loc = newWriteOp.getLoc();
 503 SmallVector<Value>
indices(newWriteOp.getIndices().begin(),
 504 newWriteOp.getIndices().end());
 507 bindDims(newWarpOp.getContext(), d0, d1);
 508 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
 511 unsigned indexPos = indexExpr.getPosition();
 512 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
 513 Value laneId = delinearized[vectorPos];
 517 rewriter, loc, d0 + scale * d1, {
indices[indexPos], laneId});
 519 newWriteOp.getIndicesMutable().assign(
indices);
// Fallback path: yield the written vector out of the warp op and issue
// the write from a fresh single-lane warp region.
 525 LogicalResult tryExtractOp(RewriterBase &rewriter,
 526 vector::TransferWriteOp writeOp,
 527 WarpExecuteOnLane0Op warpOp)
const {
 528 Location loc = writeOp.getLoc();
 529 VectorType vecType = writeOp.getVectorType();
// Bail out when the write is larger than the configured extract budget.
 531 if (vecType.getNumElements() > maxNumElementsToExtract) {
 535 "writes more elements ({0}) than allowed to extract ({1})",
 536 vecType.getNumElements(), maxNumElementsToExtract));
 540 if (llvm::all_of(warpOp.getOps(),
 541 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
 544 SmallVector<Value> yieldValues = {writeOp.getVector()};
 545 SmallVector<Type> retTypes = {vecType};
 546 SmallVector<size_t> newRetIndices;
 548 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
// Second warp op hosts the cloned write, fed by the yielded vector.
 552 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc,
TypeRange(),
 553 newWarpOp.getLaneid(),
 554 newWarpOp.getWarpSize());
 555 Block &body = secondWarpOp.getBodyRegion().front();
 558 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
 559 newWriteOp.getValueToStoreMutable().assign(
 560 newWarpOp.getResult(newRetIndices[0]));
 562 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
// Entry point: only fires when the write is the last op before the
// terminator and all its operands are either produced by the warp op's
// yield or defined outside the region.
 567 PatternRewriter &rewriter)
const override {
 568 gpu::YieldOp yield = warpOp.getTerminator();
 569 Operation *lastNode = yield->getPrevNode();
 570 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
 574 Value maybeMask = writeOp.getMask();
 575 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
 576 return writeOp.getVector() == value ||
 577 (maybeMask && maybeMask == value) ||
 578 warpOp.isDefinedOutsideOfRegion(value);
 582 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
// Masked writes are not supported by the extract fallback.
 586 if (writeOp.getMask())
 589 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
// Clones the write outside a new warp op that additionally yields the
// written vector (and mask, when present) in distributed form.
 599 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
 600 WarpExecuteOnLane0Op warpOp,
 601 vector::TransferWriteOp writeOp,
 602 VectorType targetType,
 603 VectorType maybeMaskType)
const {
 604 assert(writeOp->getParentOp() == warpOp &&
 605 "write must be nested immediately under warp");
 606 OpBuilder::InsertionGuard g(rewriter);
 607 SmallVector<size_t> newRetIndices;
 608 WarpExecuteOnLane0Op newWarpOp;
// Masked variant yields {vector, mask}; unmasked yields just the vector.
 611 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
 612 TypeRange{targetType, maybeMaskType}, newRetIndices);
 615 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
 620 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
 622 newWriteOp.getValueToStoreMutable().assign(
 623 newWarpOp.getResult(newRetIndices[0]));
 625 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
// Extract fallback budget; defaults to a single element.
 630 unsigned maxNumElementsToExtract = 1;
// Sinks an elementwise op out of the warp region: yields its operands (at
// the distributed shape of the result) and recreates the op on the
// distributed values after the warp op.
// NOTE(review): method fragment — enclosing struct header and several
// interior lines are not visible; comments describe visible code only.
 653 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 654 PatternRewriter &rewriter)
const override {
 655 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
 663 Value distributedVal = warpOp.getResult(operandIndex);
 664 SmallVector<Value> yieldValues;
 665 SmallVector<Type> retTypes;
 666 Location loc = warpOp.getLoc();
// Each operand is retyped to the distributed result shape with its own
// element type; scalar results keep the operand type unchanged.
 669 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
 671 auto operandType = cast<VectorType>(operand.
get().
getType());
 673 VectorType::get(vecType.getShape(), operandType.getElementType());
 676 assert(!isa<VectorType>(operandType) &&
 677 "unexpected yield of vector from op with scalar result type");
 678 targetType = operandType;
 680 retTypes.push_back(targetType);
 681 yieldValues.push_back(operand.
get());
 683 SmallVector<size_t> newRetIndices;
 684 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
 685 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
// Rebuild the elementwise op on the distributed operands.
 687 SmallVector<Value> newOperands(elementWise->
getOperands().begin(),
 689 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
 690 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
 692 OpBuilder::InsertionGuard g(rewriter);
 695 rewriter, loc, elementWise, newOperands,
 696 {newWarpOp.getResult(operandIndex).getType()});
// Sinks a splat arith.constant out of the warp region: rebuilds the splat
// at the distributed result type, since every lane holds the same value.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 720 PatternRewriter &rewriter)
const override {
 721 OpOperand *yieldOperand =
// Only splat (uniform) constants can be re-materialized per lane.
 726 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
 733 Attribute scalarAttr = dense.getSplatValue<Attribute>();
 735 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
 736 Location loc = warpOp.getLoc();
 738 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
// Distributes vector.step: when the step's element count equals the warp
// size, each lane's slice is simply its own lane id, so the result becomes
// a broadcast of laneid at the distributed type.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 767 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 768 PatternRewriter &rewriter)
const override {
 769 OpOperand *yieldOperand =
 770 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
// Only the exact warp-sized step is supported.
 776 if (resTy.getNumElements() !=
static_cast<int64_t
>(warpOp.getWarpSize()))
 779 llvm::formatv(
"Expected result size ({0}) to be of warp size ({1})",
 780 resTy.getNumElements(), warpOp.getWarpSize()));
 781 VectorType newVecTy =
 782 cast<VectorType>(warpOp.getResult(operandIdx).getType());
 784 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
 785 newVecTy, warpOp.getLaneid());
// Sinks a vector.transfer_read out of the warp region: yields its indices,
// padding and (optionally) mask, then re-issues the read after the warp op
// at the distributed type with lane-offset indices.
// NOTE(review): method fragment — enclosing struct header and many interior
// lines are not visible; comments describe visible code only.
 812 PatternRewriter &rewriter)
const override {
// The read must be single-use so it can be moved out wholesale.
 816 OpOperand *operand =
getWarpResult(warpOp, [](Operation *op) {
 818 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
 822 warpOp,
"warp result is not a vector.transfer_read op");
// The source memref must not depend on values inside the region.
 826 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
 828 read,
"source must be defined outside of the region");
 831 Value distributedVal = warpOp.getResult(operandIndex);
 833 SmallVector<Value, 4>
indices(read.getIndices().begin(),
 834 read.getIndices().end());
 835 auto sequentialType = cast<VectorType>(read.getResult().getType());
 836 auto distributedType = cast<VectorType>(distributedVal.
getType());
 838 AffineMap indexMap = map.
compose(read.getPermutationMap());
// Split the lane id into per-dimension components matching the
// sequential/distributed shape ratios.
 842 SmallVector<Value> delinearizedIds;
 844 distributedType.getShape(), warpOp.getWarpSize(),
 845 warpOp.getLaneid(), delinearizedIds)) {
 847 read,
"cannot delinearize lane ID for distribution");
 849 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
// Yield indices + padding (+ distributed mask) out of the warp op so the
// relocated read can consume them.
 852 OpBuilder::InsertionGuard g(rewriter);
 853 SmallVector<Value> additionalResults(
indices.begin(),
indices.end());
 854 SmallVector<Type> additionalResultTypes(
indices.size(),
 856 additionalResults.push_back(read.getPadding());
 857 additionalResultTypes.push_back(read.getPadding().getType());
 859 bool hasMask =
false;
 860 if (read.getMask()) {
// Masks are only supported with trivial permutation maps.
 870 read,
"non-trivial permutation maps not supported");
 871 VectorType maskType =
 872 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
 873 additionalResults.push_back(read.getMask());
 874 additionalResultTypes.push_back(maskType);
 877 SmallVector<size_t> newRetIndices;
 879 rewriter, warpOp, additionalResults, additionalResultTypes,
 881 distributedVal = newWarpOp.getResult(operandIndex);
 884 SmallVector<Value> newIndices;
 885 for (int64_t i = 0, e =
indices.size(); i < e; ++i)
 886 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
// Offset each distributed index by laneComponent * distributedDimSize.
 891 bindDims(read.getContext(), d0, d1);
 892 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
 895 unsigned indexPos = indexExpr.getPosition();
 896 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
 897 int64_t scale = distributedType.getDimSize(vectorPos);
 899 rewriter, read.getLoc(), d0 + scale * d1,
 900 {newIndices[indexPos], delinearizedIds[vectorPos]});
 904 Value newPadding = newWarpOp.getResult(newRetIndices[
indices.size()]);
// Mask, when present, is the last appended result.
 907 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
 909 auto newRead = vector::TransferReadOp::create(
 910 rewriter, read.getLoc(), distributedVal.
getType(), read.getBase(),
 911 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
 912 read.getInBoundsAttr());
// De-duplicates warp-op results that yield the same value: builds a new
// warp op with one result per unique yield operand and remaps old results
// through a position map.
// NOTE(review): method fragment — enclosing struct and some interior lines
// are not visible.
 924 PatternRewriter &rewriter)
const override {
 925 SmallVector<Type> newResultTypes;
 926 newResultTypes.reserve(warpOp->getNumResults());
 927 SmallVector<Value> newYieldValues;
 928 newYieldValues.reserve(warpOp->getNumResults());
 931 gpu::YieldOp yield = warpOp.getTerminator();
// First occurrence of a yield operand claims a slot; duplicates map to
// the existing slot via dedupYieldOperandPositionMap.
 942 for (OpResult
result : warpOp.getResults()) {
 945 Value yieldOperand = yield.getOperand(
result.getResultNumber());
 946 auto it = dedupYieldOperandPositionMap.insert(
 947 std::make_pair(yieldOperand, newResultTypes.size()));
 948 dedupResultPositionMap.insert(std::make_pair(
result, it.first->second));
 951 newResultTypes.push_back(
result.getType());
 952 newYieldValues.push_back(yieldOperand);
// Nothing to dedup: all yield operands were already unique.
 955 if (yield.getNumOperands() == newYieldValues.size())
 959 rewriter, warpOp, newYieldValues, newResultTypes);
 962 newWarpOp.getBody()->walk([&](Operation *op) {
// Rebuild the replacement list, mapping each old result to its deduped
// slot in the new warp op.
 968 SmallVector<Value> newValues;
 969 newValues.reserve(warpOp->getNumResults());
 970 for (OpResult
result : warpOp.getResults()) {
 972 newValues.push_back(Value());
 975 newWarpOp.getResult(dedupResultPositionMap.lookup(
result)));
// Forwards a warp-op result directly to a value when the yield operand is
// either defined outside the region or is a block argument, in which case
// the corresponding warp-op operand can be used verbatim.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 987 PatternRewriter &rewriter)
const override {
 988 gpu::YieldOp yield = warpOp.getTerminator();
 990 unsigned resultIndex;
 991 for (OpOperand &operand : yield->getOpOperands()) {
// Case 1: the yielded value itself can be forwarded.
 1000 valForwarded = operand.
get();
// Case 2: yielding a block argument of this warp op forwards the
// matching warp-op operand.
 1004 auto arg = dyn_cast<BlockArgument>(operand.
get());
 1005 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
 1007 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
 1010 valForwarded = warpOperand;
// Sinks a vector.broadcast out of the warp region: yields the (uniform)
// broadcast source and re-broadcasts it to the distributed result type
// after the warp op.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 1028 PatternRewriter &rewriter)
const override {
 1029 OpOperand *operand =
 1035 Location loc = broadcastOp.getLoc();
 1037 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
 1038 Value broadcastSrc = broadcastOp.getSource();
 1039 Type broadcastSrcType = broadcastSrc.
getType();
// The source must still be broadcastable to the distributed type.
 1046 vector::BroadcastableToResult::Success)
 1048 SmallVector<size_t> newRetIndices;
 1050 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
 1052 Value broadcasted = vector::BroadcastOp::create(
 1053 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
// Sinks a vector.shape_cast out of the warp region: yields the cast source
// at a distributed type (rank-extended with leading 1s if the distributed
// result lost leading dims) and re-applies the cast afterwards.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 1065 PatternRewriter &rewriter)
const override {
 1066 OpOperand *operand =
 1074 auto castDistributedType =
 1075 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
 1076 VectorType castOriginalType = oldCastOp.getSourceVectorType();
 1077 VectorType castResultType = castDistributedType;
// Pad the distributed type with leading unit dims so its rank matches
// the cast's original source rank.
 1081 unsigned castDistributedRank = castDistributedType.getRank();
 1082 unsigned castOriginalRank = castOriginalType.getRank();
 1083 if (castDistributedRank < castOriginalRank) {
 1084 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
 1085 llvm::append_range(shape, castDistributedType.getShape());
 1086 castDistributedType =
 1087 VectorType::get(shape, castDistributedType.getElementType());
 1090 SmallVector<size_t> newRetIndices;
 1092 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
 1095 Value newCast = vector::ShapeCastOp::create(
 1096 rewriter, oldCastOp.getLoc(), castResultType,
 1097 newWarpOp->getResult(newRetIndices[0]));
// Pattern (templated over vector::CreateMaskOp / vector::ConstantMaskOp)
// that rebuilds a mask op per lane: each lane computes its local mask bound
// as originalBound - laneComponent * distributedDimSize.
// NOTE(review): fragment — interior lines are missing; comments describe
// visible code only.
1123template <
typename OpType,
 1124 typename = std::enable_if_t<llvm::is_one_of<
 1125 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
 1128 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 1129 PatternRewriter &rewriter)
const override {
 1130 OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
// All mask operands must be uniform (defined outside the region).
 1139 !llvm::all_of(mask->
getOperands(), [&](Value value) {
 1140 return warpOp.isDefinedOutsideOfRegion(value);
 1144 Location loc = mask->
getLoc();
 1147 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
 1149 ArrayRef<int64_t> seqShape = seqType.getShape();
 1150 ArrayRef<int64_t> distShape = distType.getShape();
// create_mask already has SSA bounds; constant_mask materializes its
// static dim sizes as values first.
 1151 SmallVector<Value> materializedOperands;
 1152 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
 1153 materializedOperands.append(mask->
getOperands().begin(),
 1156 auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
 1157 auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
 1158 for (
auto dimSize : dimSizes)
 1159 materializedOperands.push_back(
 1166 SmallVector<Value> delinearizedIds;
 1167 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
 1168 warpOp.getWarpSize(), warpOp.getLaneid(),
 1171 mask,
"cannot delinearize lane ID for distribution");
 1172 assert(!delinearizedIds.empty());
// Per-dim local bound: bound - laneComponent * distShape[i], clamped by
// create_mask semantics on the rebuilt op.
 1180 SmallVector<Value> newOperands;
 1181 for (
int i = 0, e = distShape.size(); i < e; ++i) {
 1188 rewriter, loc, s1 - s0 * distShape[i],
 1189 {delinearizedIds[i], materializedOperands[i]});
 1190 newOperands.push_back(maskDimIdx);
 1194 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
// Distributes vector.insert_strided_slice: requires that the distributed
// dim of the destination is fully inserted from the source, then rebuilds
// the insert on the distributed source/dest pair.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 1229 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 1230 PatternRewriter &rewriter)
const override {
 1231 OpOperand *operand =
 1232 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
 1238 auto distributedType =
 1239 cast<VectorType>(warpOp.getResult(operandNumber).getType());
// Rank-1 results have no separable distributed dim to reason about.
 1242 if (distributedType.getRank() < 2)
 1244 insertOp,
"result vector type must be 2D or higher");
 1247 auto yieldedType = cast<VectorType>(operand->
get().
getType());
 1248 int64_t destDistributedDim =
 1250 assert(destDistributedDim != -1 &&
"could not find distributed dimension");
 1252 VectorType srcType = insertOp.getSourceVectorType();
 1253 VectorType destType = insertOp.getDestVectorType();
// Map the dest's distributed dim onto the (possibly lower-rank) source;
// it must land inside the source's trailing dims.
 1258 int64_t sourceDistributedDim =
 1259 destDistributedDim - (destType.getRank() - srcType.getRank());
 1260 if (sourceDistributedDim < 0)
 1263 "distributed dimension must be in the last k dims of dest vector");
// The slice must cover the whole distributed dim of the destination.
 1265 if (srcType.getDimSize(sourceDistributedDim) !=
 1266 destType.getDimSize(destDistributedDim))
 1268 insertOp,
"distributed dimension must be fully inserted");
 1269 SmallVector<int64_t> newSourceDistShape(
 1270 insertOp.getSourceVectorType().getShape());
 1271 newSourceDistShape[sourceDistributedDim] =
 1272 distributedType.getDimSize(destDistributedDim);
 1274 VectorType::get(newSourceDistShape, distributedType.getElementType());
 1275 VectorType newDestTy = distributedType;
 1276 SmallVector<size_t> newRetIndices;
 1277 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
 1278 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
 1279 {newSourceTy, newDestTy}, newRetIndices);
 1281 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
 1282 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
// Same offsets/strides remain valid on the distributed operands.
 1285 Value newInsert = vector::InsertStridedSliceOp::create(
 1286 rewriter, insertOp.getLoc(), distributedDest.
getType(),
 1287 distributedSource, distributedDest, insertOp.getOffsets(),
 1288 insertOp.getStrides());
// Distributes vector.extract_strided_slice: requires the distributed dim to
// be fully extracted (offset 0, full size), then re-extracts from the
// distributed source with the size along the distributed dim rescaled.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 1318 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 1319 PatternRewriter &rewriter)
const override {
 1320 OpOperand *operand =
 1321 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
 1327 auto distributedType =
 1328 cast<VectorType>(warpOp.getResult(operandNumber).getType());
 1331 if (distributedType.getRank() < 2)
 1333 extractOp,
"result vector type must be 2D or higher");
 1336 auto yieldedType = cast<VectorType>(operand->
get().
getType());
 1338 assert(distributedDim != -1 &&
"could not find distributed dimension");
 1340 int64_t numOfExtractedDims =
 1341 static_cast<int64_t
>(extractOp.getSizes().size());
// If the slice touches the distributed dim, it must span it entirely —
// a lane cannot extract a partial range of a distributed dim.
 1348 if (distributedDim < numOfExtractedDims) {
 1349 int64_t distributedDimOffset =
 1350 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
 1352 int64_t distributedDimSize =
 1353 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
 1355 if (distributedDimOffset != 0 ||
 1356 distributedDimSize != yieldedType.getDimSize(distributedDim))
 1358 extractOp,
"distributed dimension must be fully extracted");
 1360 SmallVector<int64_t> newDistributedShape(
 1361 extractOp.getSourceVectorType().getShape());
 1362 newDistributedShape[distributedDim] =
 1363 distributedType.getDimSize(distributedDim);
 1364 auto newDistributedType =
 1365 VectorType::get(newDistributedShape, distributedType.getElementType());
 1366 SmallVector<size_t> newRetIndices;
 1367 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
 1368 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
// Rescale the size attribute on the distributed dim to the per-lane size.
 1371 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
 1372 extractOp.getSizes(), [](Attribute attr) { return attr; });
 1374 if (distributedDim <
static_cast<int64_t
>(distributedSizes.size()))
 1376 distributedType.getDimSize(distributedDim));
 1380 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
 1381 Value newExtract = vector::ExtractStridedSliceOp::create(
 1382 rewriter, extractOp.getLoc(), distributedType, distributedVec,
 1383 extractOp.getOffsets(),
 1384 ArrayAttr::get(rewriter.
getContext(), distributedSizes),
 1385 extractOp.getStrides());
// Distributes vector.extract with vector result. Rank<=1 sources and
// undistributed results are handled by hoisting the extract verbatim;
// otherwise the source is yielded at a distributed type whose trailing
// dims mirror the distributed result, and the extract is re-applied.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 1396 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 1397 PatternRewriter &rewriter)
const override {
 1398 OpOperand *operand =
 1399 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
 1404 VectorType extractSrcType = extractOp.getSourceVectorType();
 1405 Location loc = extractOp.getLoc();
// Low-rank sources take a separate (simpler) path.
 1408 if (extractSrcType.getRank() <= 1) {
// Result type unchanged across the warp boundary: the extract can move
// out with the full source yielded as-is.
 1414 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
 1420 SmallVector<size_t> newRetIndices;
 1421 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
 1422 rewriter, warpOp, {extractOp.getSource()},
 1423 {extractOp.getSourceVectorType()}, newRetIndices);
 1425 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
 1427 Value newExtract = vector::ExtractOp::create(
 1428 rewriter, loc, distributedVec, extractOp.getMixedPosition());
// General path: rebuild the source's distributed shape by substituting
// the distributed result dims after the extracted indices.
 1435 auto distributedType =
 1436 cast<VectorType>(warpOp.getResult(operandNumber).getType());
 1437 auto yieldedType = cast<VectorType>(operand->
get().
getType());
 1439 assert(distributedDim != -1 &&
"could not find distributed dimension");
 1440 (void)distributedDim;
 1443 SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
 1444 for (
int i = 0; i < distributedType.getRank(); ++i)
 1445 newDistributedShape[i + extractOp.getNumIndices()] =
 1446 distributedType.getDimSize(i);
 1447 auto newDistributedType =
 1448 VectorType::get(newDistributedShape, distributedType.getElementType());
 1449 SmallVector<size_t> newRetIndices;
 1450 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
 1451 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
 1454 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
 1456 Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
 1457 extractOp.getMixedPosition());
// Distributes a scalar vector.extract from a 0-D/1-D source: the owning
// lane extracts its local element, then broadcasts it to all lanes via the
// injected warpShuffleFromIdxFn callback.
// NOTE(review): fragment — interior lines are missing; comments describe
// visible code only.
 1467 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
 1468 PatternBenefit
b = 1)
 1469 : WarpDistributionPattern(ctx,
b), warpShuffleFromIdxFn(std::move(fn)) {}
 1470 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 1471 PatternRewriter &rewriter)
const override {
 1472 OpOperand *operand =
 1473 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
 1478 VectorType extractSrcType = extractOp.getSourceVectorType();
 1480 if (extractSrcType.getRank() > 1) {
 1482 extractOp,
"only 0-D or 1-D source supported for now");
// Restriction inherited from 32-bit shuffle support.
 1486 if (!extractSrcType.getElementType().isF32() &&
 1487 !extractSrcType.getElementType().isInteger(32))
 1489 extractOp,
"only f32/i32 element types are supported");
 1490 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
 1491 Type elType = extractSrcType.getElementType();
 1492 VectorType distributedVecType;
// 1-D sources must divide evenly across the warp; each lane then holds
// elementsPerLane contiguous elements.
 1493 if (!is0dOrVec1Extract) {
 1494 assert(extractSrcType.getRank() == 1 &&
 1495 "expected that extract src rank is 0 or 1");
 1496 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
 1498 int64_t elementsPerLane =
 1499 extractSrcType.getShape()[0] / warpOp.getWarpSize();
 1500 distributedVecType = VectorType::get({elementsPerLane}, elType);
 1502 distributedVecType = extractSrcType;
// Yield the source plus any dynamic position operands out of the warp.
 1505 SmallVector<Value> additionalResults{extractOp.getSource()};
 1506 SmallVector<Type> additionalResultTypes{distributedVecType};
 1507 additionalResults.append(
 1508 SmallVector<Value>(extractOp.getDynamicPosition()));
 1509 additionalResultTypes.append(
 1510 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
 1512 Location loc = extractOp.getLoc();
 1513 SmallVector<size_t> newRetIndices;
 1514 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
 1515 rewriter, warpOp, additionalResults, additionalResultTypes,
 1518 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Single-element sources: every lane has the value; extract element 0.
 1522 if (is0dOrVec1Extract) {
 1524 SmallVector<int64_t>
indices(extractSrcType.getRank(), 0);
 1526 vector::ExtractOp::create(rewriter, loc, distributedVec,
indices);
// Compute which lane owns the requested element (pos / elementsPerLane)
// and the position within that lane's slice (pos % elementsPerLane).
 1532 int64_t staticPos = extractOp.getStaticPosition()[0];
 1533 OpFoldResult pos = ShapedType::isDynamic(staticPos)
 1534 ? (newWarpOp->getResult(newRetIndices[1]))
 1538 int64_t elementsPerLane = distributedVecType.getShape()[0];
 1542 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
 1545 elementsPerLane == 1
 1548 sym0 % elementsPerLane, pos);
 1550 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
// Broadcast the owner lane's element to every lane.
 1553 Value shuffled = warpShuffleFromIdxFn(
 1554 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
 1560 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
// Distributes a scalar vector.insert into a 0-D/1-D destination: only the
// lane owning the target position performs the insert (guarded by scf.if);
// the other lanes pass their distributed vector through unchanged.
// NOTE(review): method fragment — enclosing struct and trailing lines are
// not visible.
 1567 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
 1568 PatternRewriter &rewriter)
const override {
 1569 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
 1574 VectorType vecType = insertOp.getDestVectorType();
 1575 VectorType distrType =
 1576 cast<VectorType>(warpOp.getResult(operandNumber).getType());
 1579 if (vecType.getRank() > 1) {
 1581 insertOp,
"only 0-D or 1-D source supported for now");
// Yield dest (distributed), the stored scalar, and any dynamic position.
 1585 SmallVector<Value> additionalResults{insertOp.getDest(),
 1586 insertOp.getValueToStore()};
 1587 SmallVector<Type> additionalResultTypes{
 1588 distrType, insertOp.getValueToStore().getType()};
 1589 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
 1590 additionalResultTypes.append(
 1591 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
 1593 Location loc = insertOp.getLoc();
 1594 SmallVector<size_t> newRetIndices;
 1595 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
 1596 rewriter, warpOp, additionalResults, additionalResultTypes,
 1599 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
 1600 Value newSource = newWarpOp->getResult(newRetIndices[1]);
// For non-0-D dests, resolve the (possibly dynamic) insert position.
 1604 if (vecType.getRank() != 0) {
 1605 int64_t staticPos = insertOp.getStaticPosition()[0];
 1606 pos = ShapedType::isDynamic(staticPos)
 1607 ? (newWarpOp->getResult(newRetIndices[2]))
// Undistributed dest: every lane inserts locally — no guard needed.
 1612 if (vecType == distrType) {
 1614 SmallVector<OpFoldResult>
indices;
 1618 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
// Distributed dest: compute owning lane (pos / elementsPerLane) and the
// local slot (pos % elementsPerLane); only the owner inserts.
 1627 int64_t elementsPerLane = distrType.getShape()[0];
 1631 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
 1634 rewriter, loc, sym0 % elementsPerLane, pos);
 1635 Value isInsertingLane =
 1636 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
 1637 newWarpOp.getLaneid(), insertingLane);
 1640 rewriter, loc, isInsertingLane,
// then: owner lane inserts into its slice.
 1642 [&](OpBuilder &builder, Location loc) {
 1643 Value newInsert = vector::InsertOp::create(
 1644 builder, loc, newSource, distributedVec, newPos);
 1645 scf::YieldOp::create(builder, loc, newInsert);
// else: non-owner lanes keep their slice unchanged.
 1648 [&](OpBuilder &builder, Location loc) {
 1649 scf::YieldOp::create(builder, loc, distributedVec);
1659 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1660 PatternRewriter &rewriter)
const override {
1661 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1666 Location loc = insertOp.getLoc();
1669 if (insertOp.getDestVectorType().getRank() <= 1) {
1675 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1678 SmallVector<size_t> newRetIndices;
1679 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1680 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1681 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1684 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1685 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1686 Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1688 insertOp.getMixedPosition());
1695 auto distrDestType =
1696 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1697 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1698 int64_t distrDestDim = -1;
1699 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1700 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1703 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1707 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1710 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1711 SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1718 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1719 if (distrSrcDim >= 0)
1720 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1722 VectorType::get(distrSrcShape, distrDestType.getElementType());
1725 SmallVector<size_t> newRetIndices;
1726 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1727 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1728 {distrSrcType, distrDestType}, newRetIndices);
1730 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1731 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1735 if (distrSrcDim >= 0) {
1737 newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1739 insertOp.getMixedPosition());
1742 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1743 SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1747 rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1748 Value isInsertingLane =
1749 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1750 newWarpOp.getLaneid(), insertingLane);
1752 newPos[distrDestDim] %= elementsPerLane;
1753 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1754 Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1755 distributedDest, newPos);
1756 scf::YieldOp::create(builder, loc, newInsert);
1758 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1759 scf::YieldOp::create(builder, loc, distributedDest);
1761 newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1763 nonInsertingBuilder)
1800 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
1802 PatternRewriter &rewriter)
const override {
1803 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1805 Operation *lastNode = warpOpYield->getPrevNode();
1806 auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1817 SmallVector<Value> nonIfYieldValues;
1818 SmallVector<unsigned> nonIfYieldIndices;
1819 llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1820 llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1821 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1824 nonIfYieldValues.push_back(yieldOperand.
get());
1825 nonIfYieldIndices.push_back(yieldOperandIdx);
1828 OpResult ifResult = cast<OpResult>(yieldOperand.
get());
1830 ifResultMapping[yieldOperandIdx] = ifResultIdx;
1833 if (!isa<VectorType>(ifResult.
getType()))
1835 VectorType distType =
1836 cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1837 ifResultDistTypes[ifResultIdx] = distType;
1842 auto [escapingValuesThen, escapingValueInputTypesThen,
1843 escapingValueDistTypesThen] =
1844 getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1846 auto [escapingValuesElse, escapingValueInputTypesElse,
1847 escapingValueDistTypesElse] =
1848 getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1850 if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1851 llvm::is_contained(escapingValueDistTypesElse, Type{}))
1859 SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1860 newWarpOpYieldValues.append(escapingValuesThen.begin(),
1861 escapingValuesThen.end());
1862 newWarpOpYieldValues.append(escapingValuesElse.begin(),
1863 escapingValuesElse.end());
1864 SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1865 newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1866 escapingValueDistTypesThen.end());
1867 newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1868 escapingValueDistTypesElse.end());
1870 for (
auto [idx, val] :
1871 llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1872 newWarpOpYieldValues.push_back(val);
1873 newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1877 SmallVector<size_t> newIndices;
1879 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
1881 SmallVector<Type> newIfOpDistResTypes;
1882 for (
auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1883 Type distType = cast<Value>(res).getType();
1884 if (
auto vecType = dyn_cast<VectorType>(distType)) {
1885 AffineMap map = distributionMapFn(cast<Value>(res));
1887 distType = ifResultDistTypes.count(i)
1888 ? ifResultDistTypes[i]
1889 : getDistributedType(vecType, map, warpOp.getWarpSize());
1891 newIfOpDistResTypes.push_back(distType);
1894 OpBuilder::InsertionGuard g(rewriter);
1896 auto newIfOp = scf::IfOp::create(
1897 rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1898 newWarpOp.getResult(newIndices[0]),
static_cast<bool>(ifOp.thenBlock()),
1899 static_cast<bool>(ifOp.elseBlock()));
1900 auto encloseRegionInWarpOp =
1902 llvm::SmallSetVector<Value, 32> &escapingValues,
1903 SmallVector<Type> &escapingValueInputTypes,
1904 size_t warpResRangeStart) {
1905 OpBuilder::InsertionGuard g(rewriter);
1909 llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1910 SmallVector<Value> innerWarpInputVals;
1911 SmallVector<Type> innerWarpInputTypes;
1912 for (
size_t i = 0; i < escapingValues.size();
1913 ++i, ++warpResRangeStart) {
1914 innerWarpInputVals.push_back(
1915 newWarpOp.getResult(newIndices[warpResRangeStart]));
1916 escapeValToBlockArgIndex[escapingValues[i]] =
1917 innerWarpInputTypes.size();
1918 innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1920 auto innerWarp = WarpExecuteOnLane0Op::create(
1921 rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1922 newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1923 innerWarpInputVals, innerWarpInputTypes);
1925 innerWarp.getWarpRegion().takeBody(*oldIfBranch->
getParent());
1926 innerWarp.getWarpRegion().addArguments(
1927 innerWarpInputTypes,
1928 SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1930 SmallVector<Value> yieldOperands;
1932 yieldOperands.push_back(operand);
1936 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1938 scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1942 innerWarp.walk([&](Operation *op) {
1944 auto it = escapeValToBlockArgIndex.find(operand.
get());
1945 if (it == escapeValToBlockArgIndex.end())
1947 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1950 mlir::vector::moveScalarUniformCode(innerWarp);
1952 encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1953 &newIfOp.getThenRegion().front(), escapingValuesThen,
1954 escapingValueInputTypesThen, 1);
1955 if (!ifOp.getElseRegion().empty())
1956 encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1957 &newIfOp.getElseRegion().front(),
1958 escapingValuesElse, escapingValueInputTypesElse,
1959 1 + escapingValuesThen.size());
1962 for (
auto [origIdx, newIdx] : ifResultMapping)
1964 newIfOp.getResult(newIdx), newIfOp);
2007 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
2008 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2009 PatternRewriter &rewriter)
const override {
2010 gpu::YieldOp warpOpYield = warpOp.getTerminator();
2012 Operation *lastNode = warpOpYield->getPrevNode();
2013 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
2018 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2019 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2021 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2032 SmallVector<Value> nonForYieldedValues;
2033 SmallVector<unsigned> nonForResultIndices;
2034 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2035 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2036 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2039 nonForYieldedValues.push_back(yieldOperand.
get());
2043 OpResult forResult = cast<OpResult>(yieldOperand.
get());
2048 if (!isa<VectorType>(forResult.
getType()))
2050 VectorType distType = cast<VectorType>(
2052 forResultDistTypes[forResultNumber] = distType;
2060 SmallVector<Value> newWarpOpYieldValues;
2061 SmallVector<Type> newWarpOpDistTypes;
2062 newWarpOpYieldValues.insert(
2063 newWarpOpYieldValues.end(),
2064 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2065 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2066 {forOp.getLowerBound().getType(),
2067 forOp.getUpperBound().getType(),
2068 forOp.getStep().getType()});
2069 for (
auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2070 newWarpOpYieldValues.push_back(initArg);
2072 Type distType = initArg.getType();
2073 if (
auto vecType = dyn_cast<VectorType>(distType)) {
2077 AffineMap map = distributionMapFn(initArg);
2078 distType = forResultDistTypes.count(i)
2079 ? forResultDistTypes[i]
2080 : getDistributedType(vecType, map, warpOp.getWarpSize());
2082 newWarpOpDistTypes.push_back(distType);
2085 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2086 escapingValues.begin(), escapingValues.end());
2087 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2088 escapingValueDistTypes.begin(),
2089 escapingValueDistTypes.end());
2093 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2094 newWarpOpYieldValues.push_back(v);
2095 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2098 SmallVector<size_t> newIndices;
2099 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2100 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2104 const unsigned initArgsStartIdx = 3;
2105 const unsigned escapingValuesStartIdx =
2107 forOp.getInitArgs().size();
2109 SmallVector<Value> newForOpOperands;
2110 for (
size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2111 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2114 OpBuilder::InsertionGuard g(rewriter);
2116 auto newForOp = scf::ForOp::create(
2117 rewriter, forOp.getLoc(),
2118 newWarpOp.getResult(newIndices[0]),
2119 newWarpOp.getResult(newIndices[1]),
2120 newWarpOp.getResult(newIndices[2]), newForOpOperands,
2121 nullptr, forOp.getUnsignedCmp());
2127 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2128 newForOp.getRegionIterArgs().end());
2129 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2130 forOp.getResultTypes().end());
2134 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2135 for (
size_t i = escapingValuesStartIdx;
2136 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2137 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2138 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2139 innerWarpInputType.size();
2140 innerWarpInputType.push_back(
2141 escapingValueInputTypes[i - escapingValuesStartIdx]);
2144 auto innerWarp = WarpExecuteOnLane0Op::create(
2145 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2146 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2147 innerWarpInputType);
2150 SmallVector<Value> argMapping;
2151 argMapping.push_back(newForOp.getInductionVar());
2152 for (Value args : innerWarp.getBody()->getArguments())
2153 argMapping.push_back(args);
2155 argMapping.resize(forOp.getBody()->getNumArguments());
2156 SmallVector<Value> yieldOperands;
2157 for (Value operand : forOp.getBody()->getTerminator()->getOperands())
2158 yieldOperands.push_back(operand);
2160 rewriter.
eraseOp(forOp.getBody()->getTerminator());
2161 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2166 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2170 if (!innerWarp.getResults().empty())
2171 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2175 for (
auto [origIdx, newIdx] : forResultMapping)
2177 newForOp.getResult(newIdx), newForOp);
2180 newForOp.walk([&](Operation *op) {
2182 auto it = argIndexMapping.find(operand.
get());
2183 if (it == argIndexMapping.end())
2185 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
2190 mlir::vector::moveScalarUniformCode(innerWarp);
2218 WarpOpReduction(MLIRContext *context,
2219 DistributedReductionFn distributedReductionFn,
2220 PatternBenefit benefit = 1)
2221 : WarpDistributionPattern(context, benefit),
2222 distributedReductionFn(std::move(distributedReductionFn)) {}
2224 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2225 PatternRewriter &rewriter)
const override {
2226 OpOperand *yieldOperand =
2227 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2233 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2235 if (vectorType.getRank() != 1)
2237 warpOp,
"Only rank 1 reductions can be distributed.");
2239 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2241 warpOp,
"Reduction vector dimension must match was size.");
2242 if (!reductionOp.getType().isIntOrFloat())
2244 warpOp,
"Reduction distribution currently only supports floats and "
2247 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2250 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2251 SmallVector<Type> retTypes = {
2252 VectorType::get({numElements}, reductionOp.getType())};
2253 if (reductionOp.getAcc()) {
2254 yieldValues.push_back(reductionOp.getAcc());
2255 retTypes.push_back(reductionOp.getAcc().getType());
2257 SmallVector<size_t> newRetIndices;
2258 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2259 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2263 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2266 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2267 reductionOp.getKind(), newWarpOp.getWarpSize());
2268 if (reductionOp.getAcc()) {
2270 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2271 newWarpOp.getResult(newRetIndices[1]));
2278 DistributedReductionFn distributedReductionFn;
2289void mlir::vector::populateDistributeTransferWriteOpPatterns(
2292 patterns.add<WarpOpTransferWrite>(
patterns.getContext(), distributionMapFn,
2293 maxNumElementsToExtract, benefit);
2296void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2298 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
2301 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2302 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
2303 WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
2304 WarpOpCreateMask<vector::CreateMaskOp>,
2305 WarpOpCreateMask<vector::ConstantMaskOp>,
2306 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2308 patterns.add<WarpOpExtractScalar>(
patterns.getContext(), warpShuffleFromIdxFn,
2316void mlir::vector::populateDistributeReduction(
2318 const DistributedReductionFn &distributedReductionFn,
2320 patterns.add<WarpOpReduction>(
patterns.getContext(), distributedReductionFn,
2327 return llvm::all_of(op->
getOperands(), definedOutside) &&
2331void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2332 Block *body = warpOp.getBody();
2335 llvm::SmallSetVector<Operation *, 8> opsToMove;
2338 auto isDefinedOutsideOfBody = [&](
Value value) {
2340 return (definingOp && opsToMove.count(definingOp)) ||
2341 warpOp.isDefinedOutsideOfRegion(value);
2348 return isa<VectorType>(result.getType());
2350 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
2351 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. AffineMaps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.