22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include "llvm/Support/FormatVariadic.h"
// NOTE(review): garbled extraction — the signature start and several body lines
// are missing, and original file line numbers are fused into the text.
// Visible logic compares sequentialType vs. distributedType dim-by-dim and
// builds something from distributedType.getContext() — presumably an AffineMap
// describing which dims are distributed; confirm against the original file.
43 VectorType distributedType) {
49 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
54 distributedType.getContext());
// NOTE(review): fragment — finds the single dimension whose size differs
// between the sequential and distributed vector types (the "distributed dim").
// The assert shows exactly one such dim is allowed; returns -1 presumably when
// none differs (initialization of distributedDim is not visible — confirm).
62 VectorType distributedType) {
63 assert(sequentialType.getRank() == distributedType.getRank() &&
64 "sequential and distributed vector types must have the same rank");
66 for (
int64_t i = 0; i < sequentialType.getRank(); ++i) {
67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
70 assert(distributedDim == -1 &&
"found multiple distributed dims");
74 return distributedDim;
// Helper pairing a sequential value with its distributed counterpart and
// emitting the memref stores/loads that move data between them via a shared
// buffer, offset by the lane ID along the distributed dimensions.
// NOTE(review): garbled extraction — interior lines are missing and original
// file line numbers are fused into the text; code kept byte-identical.
84struct DistributedLoadStoreHelper {
85 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
86 Value laneId, Value zero)
87 : sequentialVal(sequentialVal), distributedVal(distributedVal),
88 laneId(laneId), zero(zero) {
89 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
90 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
91 if (sequentialVectorType && distributedVectorType)
// Per-lane offset along `index`: laneId * distributedSize, folded via an
// affine.apply.
96 Value buildDistributedOffset(RewriterBase &
b, Location loc, int64_t index) {
97 int64_t distributedSize = distributedVectorType.getDimSize(index);
99 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
100 ArrayRef<Value>{laneId});
// Store `val` (which must be one of the two preregistered values) into the
// buffer: scalar path uses memref.store at offset zero; vector path uses a
// transfer_write with lane-based indices when storing the distributed value.
110 Operation *buildStore(RewriterBase &
b, Location loc, Value val,
112 assert((val == distributedVal || val == sequentialVal) &&
113 "Must store either the preregistered distributed or the "
114 "preregistered sequential value.");
116 if (!isa<VectorType>(val.
getType()))
117 return memref::StoreOp::create(
b, loc, val, buffer, zero);
121 int64_t rank = sequentialVectorType.getRank();
122 SmallVector<Value>
indices(rank, zero);
123 if (val == distributedVal) {
124 for (
auto dimExpr : distributionMap.getResults()) {
125 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126 indices[index] = buildDistributedOffset(
b, loc, index);
129 SmallVector<bool> inBounds(
indices.size(),
true);
130 return vector::TransferWriteOp::create(
132 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
// Mirror of buildStore: load `type` back out of the buffer, with lane-based
// indices when loading the distributed type.
155 Value buildLoad(RewriterBase &
b, Location loc, Type type, Value buffer) {
158 if (!isa<VectorType>(type))
159 return memref::LoadOp::create(
b, loc, buffer, zero);
164 assert((type == distributedVectorType || type == sequentialVectorType) &&
165 "Must store either the preregistered distributed or the "
166 "preregistered sequential type.");
167 SmallVector<Value>
indices(sequentialVectorType.getRank(), zero);
168 if (type == distributedVectorType) {
169 for (
auto dimExpr : distributionMap.getResults()) {
170 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171 indices[index] = buildDistributedOffset(
b, loc, index);
174 SmallVector<bool> inBounds(
indices.size(),
true);
175 return vector::TransferReadOp::create(
176 b, loc, cast<VectorType>(type), buffer,
indices,
178 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
// Preregistered values and types; distributionMap maps distributed dims.
181 Value sequentialVal, distributedVal, laneId, zero;
182 VectorType sequentialVectorType, distributedVectorType;
183 AffineMap distributionMap;
// NOTE(review): orphan fragment of an unknown enclosing function — the
// surrounding context was lost in extraction. Appears to build and return an
// op via the rewriter; cannot determine which op without the original file.
196 return rewriter.
create(res);
// WarpOpToScfIfPattern (visible interior): lowers a WarpExecuteOnLane0Op into
// an scf.if guarded by `laneid == 0`. Distributed inputs are staged through
// buffers (allocated by options.warpAllocationFn), the sequential body is
// merged into the then-block, and yielded sequential values are staged back
// out and loaded as distributed results, with warpSyncronizationFn barriers.
// NOTE(review): garbled extraction — interior lines missing; code unchanged.
230 WarpOpToScfIfPattern(MLIRContext *context,
231 const WarpExecuteOnLane0LoweringOptions &options,
232 PatternBenefit benefit = 1)
233 : WarpDistributionPattern(context, benefit), options(options) {}
235 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236 PatternRewriter &rewriter)
const override {
237 assert(warpOp.getBodyRegion().hasOneBlock() &&
238 "expected WarpOp with single block");
239 Block *warpOpBody = &warpOp.getBodyRegion().front();
240 Location loc = warpOp.getLoc();
243 OpBuilder::InsertionGuard g(rewriter);
// Guard: only lane 0 executes the sequential body.
248 Value isLane0 = arith::CmpIOp::create(
249 rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
250 auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
252 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
// Stage each distributed block argument through a buffer so lane 0 can
// read the full sequential value.
256 SmallVector<Value> bbArgReplacements;
257 for (
const auto &it : llvm::enumerate(warpOp.getArgs())) {
258 Value sequentialVal = warpOpBody->
getArgument(it.index());
259 Value distributedVal = it.value();
260 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
261 warpOp.getLaneid(), c0);
265 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
268 helper.buildStore(rewriter, loc, distributedVal, buffer);
271 bbArgReplacements.push_back(
272 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
276 if (!warpOp.getArgs().empty()) {
278 options.warpSyncronizationFn(loc, rewriter, warpOp);
282 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
// Stage each yielded sequential value back out as a distributed result.
288 SmallVector<Value> replacements;
289 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
290 Location yieldLoc = yieldOp.getLoc();
291 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
292 Value sequentialVal = it.value();
293 Value distributedVal = warpOp->getResult(it.index());
294 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
295 warpOp.getLaneid(), c0);
299 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
305 helper.buildStore(rewriter, loc, sequentialVal, buffer);
316 replacements.push_back(
317 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
321 if (!yieldOp.getOperands().empty()) {
323 options.warpSyncronizationFn(loc, rewriter, warpOp);
329 scf::YieldOp::create(rewriter, yieldLoc);
332 rewriter.
replaceOp(warpOp, replacements);
338 const WarpExecuteOnLane0LoweringOptions &options;
// Computes the per-lane (distributed) vector type: each mapped dimension of
// targetShape is divided across warpSize lanes; when the dim does not divide
// evenly, warpSize is split across dims (dim must then divide warpSize).
// NOTE(review): fragment — signature tail, early-exit returns, and the loop
// header over `map` results are missing from this extraction.
351static VectorType getDistributedType(VectorType originalType,
AffineMap map,
359 if (targetShape[position] % warpSize != 0) {
360 if (warpSize % targetShape[position] != 0) {
363 warpSize /= targetShape[position];
364 targetShape[position] = 1;
367 targetShape[position] = targetShape[position] / warpSize;
374 VectorType targetType =
375 VectorType::get(targetShape, originalType.getElementType());
// Collects values defined above `innerRegion` but inside the warp op
// ("escaping" values), returning them with their original types and their
// distributed types (vector types are narrowed via getDistributedType).
// NOTE(review): fragment — the return-type line, the walk over region
// operands, and several declarations are missing from this extraction.
385getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp,
Region &innerRegion,
387 llvm::SmallSetVector<Value, 32> escapingValues;
390 if (innerRegion.
empty())
391 return {std::move(escapingValues), std::move(escapingValueTypes),
392 std::move(escapingValueDistTypes)};
395 if (warpOp->isAncestor(parent)) {
396 if (!escapingValues.insert(operand->
get()))
399 if (
auto vecType = dyn_cast<VectorType>(distType)) {
401 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
403 escapingValueTypes.push_back(operand->
get().
getType());
404 escapingValueDistTypes.push_back(distType);
407 return {std::move(escapingValues), std::move(escapingValueTypes),
408 std::move(escapingValueDistTypes)};
// Transfer-write distribution pattern (visible interior): distributes a
// vector.transfer_write that is the last op before the warp terminator.
// tryDistributeOp rewrites the write to per-lane indices; tryExtractOp falls
// back to extracting up to maxNumElementsToExtract scalars inside a second,
// lane-0-only warp op; cloneWriteOp moves vector (and mask) through new warp
// results. NOTE(review): garbled extraction — many lines missing; unchanged.
432 unsigned maxNumElementsToExtract, PatternBenefit
b = 1)
433 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)),
434 maxNumElementsToExtract(maxNumElementsToExtract) {}
// Rewrite the write so each lane writes its distributed slice at
// lane-derived indices.
438 LogicalResult tryDistributeOp(RewriterBase &rewriter,
439 vector::TransferWriteOp writeOp,
440 WarpExecuteOnLane0Op warpOp)
const {
441 VectorType writtenVectorType = writeOp.getVectorType();
445 if (writtenVectorType.getRank() == 0)
449 AffineMap map = distributionMapFn(writeOp.getVector());
450 VectorType targetType =
451 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
457 if (writeOp.getMask()) {
464 if (!writeOp.getPermutationMap().isMinorIdentity())
467 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
472 vector::TransferWriteOp newWriteOp =
473 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
477 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
// Delinearize the lane ID across the per-dim distribution factors
// (seqSize / distSize per dimension).
483 SmallVector<OpFoldResult> delinearizedIdSizes;
484 for (
auto [seqSize, distSize] :
485 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
486 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
487 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
489 SmallVector<Value> delinearized;
491 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
492 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
498 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
// Offset each written index by laneId * distributedDimSize.
501 AffineMap indexMap = map.
compose(newWriteOp.getPermutationMap());
502 Location loc = newWriteOp.getLoc();
503 SmallVector<Value>
indices(newWriteOp.getIndices().begin(),
504 newWriteOp.getIndices().end());
507 bindDims(newWarpOp.getContext(), d0, d1);
508 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
511 unsigned indexPos = indexExpr.getPosition();
512 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
513 Value laneId = delinearized[vectorPos];
517 rewriter, loc, d0 + scale * d1, {
indices[indexPos], laneId});
519 newWriteOp.getIndicesMutable().assign(
indices);
// Fallback: move the small written vector out of the warp and do the write
// in a fresh lane-0-only warp op.
525 LogicalResult tryExtractOp(RewriterBase &rewriter,
526 vector::TransferWriteOp writeOp,
527 WarpExecuteOnLane0Op warpOp)
const {
528 Location loc = writeOp.getLoc();
529 VectorType vecType = writeOp.getVectorType();
531 if (vecType.getNumElements() > maxNumElementsToExtract) {
535 "writes more elements ({0}) than allowed to extract ({1})",
536 vecType.getNumElements(), maxNumElementsToExtract));
540 if (llvm::all_of(warpOp.getOps(),
541 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
544 SmallVector<Value> yieldValues = {writeOp.getVector()};
545 SmallVector<Type> retTypes = {vecType};
546 SmallVector<size_t> newRetIndices;
548 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
552 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc,
TypeRange(),
553 newWarpOp.getLaneid(),
554 newWarpOp.getWarpSize());
555 Block &body = secondWarpOp.getBodyRegion().front();
558 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
559 newWriteOp.getValueToStoreMutable().assign(
560 newWarpOp.getResult(newRetIndices[0]));
562 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
567 PatternRewriter &rewriter)
const override {
568 gpu::YieldOp yield = warpOp.getTerminator();
569 Operation *lastNode = yield->getPrevNode();
570 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
// All write operands must come from outside the region (or be the
// vector/mask being distributed).
574 Value maybeMask = writeOp.getMask();
575 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
576 return writeOp.getVector() == value ||
577 (maybeMask && maybeMask == value) ||
578 warpOp.isDefinedOutsideOfRegion(value);
582 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
586 if (writeOp.getMask())
589 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
// Clone the write outside the warp, routing its vector (and mask, if any)
// through new warp results of the distributed types.
599 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
600 WarpExecuteOnLane0Op warpOp,
601 vector::TransferWriteOp writeOp,
602 VectorType targetType,
603 VectorType maybeMaskType)
const {
604 assert(writeOp->getParentOp() == warpOp &&
605 "write must be nested immediately under warp");
606 OpBuilder::InsertionGuard g(rewriter);
607 SmallVector<size_t> newRetIndices;
608 WarpExecuteOnLane0Op newWarpOp;
611 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
612 TypeRange{targetType, maybeMaskType}, newRetIndices);
615 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
620 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
622 newWriteOp.getValueToStoreMutable().assign(
623 newWarpOp.getResult(newRetIndices[0]));
625 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
630 unsigned maxNumElementsToExtract = 1;
// Elementwise distribution (visible interior of matchAndRewrite): an
// elementwise op yielded by the warp is sunk outside it — each operand is
// yielded with the distributed element type, then the op is recreated on the
// distributed operands with the distributed result type.
// NOTE(review): garbled extraction — interior lines missing; code unchanged.
653 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
654 PatternRewriter &rewriter)
const override {
655 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
663 Value distributedVal = warpOp.getResult(operandIndex);
664 SmallVector<Value> yieldValues;
665 SmallVector<Type> retTypes;
666 Location loc = warpOp.getLoc();
// Each operand is distributed to the result's distributed shape with the
// operand's own element type; scalar operands pass through unchanged.
669 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
671 auto operandType = cast<VectorType>(operand.
get().
getType());
673 VectorType::get(vecType.getShape(), operandType.getElementType());
676 assert(!isa<VectorType>(operandType) &&
677 "unexpected yield of vector from op with scalar result type");
678 targetType = operandType;
680 retTypes.push_back(targetType);
681 yieldValues.push_back(operand.
get());
683 SmallVector<size_t> newRetIndices;
684 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
685 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
687 SmallVector<Value> newOperands(elementWise->
getOperands().begin(),
689 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
690 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
692 OpBuilder::InsertionGuard g(rewriter);
695 rewriter, loc, elementWise, newOperands,
696 {newWarpOp.getResult(operandIndex).getType()});
// Splat-constant distribution (visible interior): a yielded splat constant is
// rebuilt outside the warp as a constant of the distributed shaped type with
// the same splat scalar value.
// NOTE(review): fragment — enclosing signature and several lines missing.
720 PatternRewriter &rewriter)
const override {
721 OpOperand *yieldOperand =
726 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
733 Attribute scalarAttr = dense.getSplatValue<Attribute>();
735 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
736 Location loc = warpOp.getLoc();
738 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
// vector.step distribution (visible interior): a yielded step whose element
// count equals the warp size is replaced by a broadcast of the lane ID (each
// lane's slice of the step IS its lane ID).
// NOTE(review): fragment — interior lines missing; code unchanged.
767 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
768 PatternRewriter &rewriter)
const override {
769 OpOperand *yieldOperand =
770 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
776 if (resTy.getNumElements() !=
static_cast<int64_t
>(warpOp.getWarpSize()))
779 llvm::formatv(
"Expected result size ({0}) to be of warp size ({1})",
780 resTy.getNumElements(), warpOp.getWarpSize()));
781 VectorType newVecTy =
782 cast<VectorType>(warpOp.getResult(operandIdx).getType());
784 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
785 newVecTy, warpOp.getLaneid());
// Transfer-read distribution (visible interior): a single-use transfer_read
// yielded by the warp, whose base is defined outside the region, is moved
// outside; its indices, padding, and (optional) mask are routed through new
// warp results, and the read indices are offset by the delinearized lane ID
// scaled by the distributed dim sizes.
// NOTE(review): garbled extraction — interior lines missing; code unchanged.
812 PatternRewriter &rewriter)
const override {
816 OpOperand *operand =
getWarpResult(warpOp, [](Operation *op) {
818 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
822 warpOp,
"warp result is not a vector.transfer_read op");
826 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
828 read,
"source must be defined outside of the region");
831 Value distributedVal = warpOp.getResult(operandIndex);
833 SmallVector<Value, 4>
indices(read.getIndices().begin(),
834 read.getIndices().end());
835 auto sequentialType = cast<VectorType>(read.getResult().getType());
836 auto distributedType = cast<VectorType>(distributedVal.
getType());
838 AffineMap indexMap = map.
compose(read.getPermutationMap());
842 SmallVector<Value> delinearizedIds;
844 distributedType.getShape(), warpOp.getWarpSize(),
845 warpOp.getLaneid(), delinearizedIds)) {
847 read,
"cannot delinearize lane ID for distribution");
849 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
// Forward indices + padding (+ mask, distributed) as extra warp results.
852 OpBuilder::InsertionGuard g(rewriter);
853 SmallVector<Value> additionalResults(
indices.begin(),
indices.end());
854 SmallVector<Type> additionalResultTypes(
indices.size(),
856 additionalResults.push_back(read.getPadding());
857 additionalResultTypes.push_back(read.getPadding().getType());
859 bool hasMask =
false;
860 if (read.getMask()) {
870 read,
"non-trivial permutation maps not supported");
871 VectorType maskType =
872 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
873 additionalResults.push_back(read.getMask());
874 additionalResultTypes.push_back(maskType);
877 SmallVector<size_t> newRetIndices;
879 rewriter, warpOp, additionalResults, additionalResultTypes,
881 distributedVal = newWarpOp.getResult(operandIndex);
884 SmallVector<Value> newIndices;
885 for (int64_t i = 0, e =
indices.size(); i < e; ++i)
886 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
// Offset each mapped index: index + laneIdComponent * distributedDimSize.
891 bindDims(read.getContext(), d0, d1);
892 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
895 unsigned indexPos = indexExpr.getPosition();
896 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
897 int64_t scale = distributedType.getDimSize(vectorPos);
899 rewriter, read.getLoc(), d0 + scale * d1,
900 {newIndices[indexPos], delinearizedIds[vectorPos]});
904 Value newPadding = newWarpOp.getResult(newRetIndices[
indices.size()]);
907 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
909 auto newRead = vector::TransferReadOp::create(
910 rewriter, read.getLoc(), distributedVal.
getType(), read.getBase(),
911 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
912 read.getInBoundsAttr());
// Dead/duplicate-result cleanup (visible interior): rebuilds the warp op
// keeping only live, deduplicated yield operands. Two maps track the first
// position of each yielded value and where each old result lands, so old
// results can be remapped to the deduplicated new results.
// NOTE(review): fragment — interior lines missing; code unchanged.
924 PatternRewriter &rewriter)
const override {
925 SmallVector<Type> newResultTypes;
926 newResultTypes.reserve(warpOp->getNumResults());
927 SmallVector<Value> newYieldValues;
928 newYieldValues.reserve(warpOp->getNumResults());
931 gpu::YieldOp yield = warpOp.getTerminator();
942 for (OpResult
result : warpOp.getResults()) {
945 Value yieldOperand = yield.getOperand(
result.getResultNumber());
946 auto it = dedupYieldOperandPositionMap.insert(
947 std::make_pair(yieldOperand, newResultTypes.size()));
948 dedupResultPositionMap.insert(std::make_pair(
result, it.first->second));
951 newResultTypes.push_back(
result.getType());
952 newYieldValues.push_back(yieldOperand);
// Nothing to do if no operand was dropped or deduplicated.
955 if (yield.getNumOperands() == newYieldValues.size())
959 rewriter, warpOp, newYieldValues, newResultTypes);
962 newWarpOp.getBody()->walk([&](Operation *op) {
968 SmallVector<Value> newValues;
969 newValues.reserve(warpOp->getNumResults());
970 for (OpResult
result : warpOp.getResults()) {
972 newValues.push_back(Value());
975 newWarpOp.getResult(dedupResultPositionMap.lookup(
result)));
// Operand forwarding (visible interior): when a yielded value is defined
// outside the warp region, or is a block argument of the warp body, the
// corresponding warp result can be replaced directly by the outer value
// (valForwarded) without going through the warp op.
// NOTE(review): fragment — interior lines missing; code unchanged.
987 PatternRewriter &rewriter)
const override {
988 gpu::YieldOp yield = warpOp.getTerminator();
990 unsigned resultIndex;
991 for (OpOperand &operand : yield->getOpOperands()) {
1000 valForwarded = operand.
get();
1004 auto arg = dyn_cast<BlockArgument>(operand.
get());
1005 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1007 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1010 valForwarded = warpOperand;
// Broadcast distribution (visible interior): a yielded vector.broadcast is
// sunk outside the warp — the (lane-uniform) source is yielded as a new warp
// result and re-broadcast to the distributed result type.
// NOTE(review): fragment — interior lines missing; code unchanged.
1028 PatternRewriter &rewriter)
const override {
1029 OpOperand *operand =
1035 Location loc = broadcastOp.getLoc();
1037 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1038 Value broadcastSrc = broadcastOp.getSource();
1039 Type broadcastSrcType = broadcastSrc.
getType();
// Broadcastability to the distributed type is checked before rewriting.
1046 vector::BroadcastableToResult::Success)
1048 SmallVector<size_t> newRetIndices;
1050 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1052 Value broadcasted = vector::BroadcastOp::create(
1053 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
// Shape-cast distribution (visible interior): a yielded vector.shape_cast is
// sunk outside the warp. When the distributed result rank is lower than the
// cast's source rank, the distributed source type is padded with leading unit
// dims so the source can be yielded distributed, then re-cast outside.
// NOTE(review): fragment — interior lines missing; code unchanged.
1065 PatternRewriter &rewriter)
const override {
1066 OpOperand *operand =
1074 auto castDistributedType =
1075 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1076 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1077 VectorType castResultType = castDistributedType;
1081 unsigned castDistributedRank = castDistributedType.getRank();
1082 unsigned castOriginalRank = castOriginalType.getRank();
1083 if (castDistributedRank < castOriginalRank) {
1084 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1085 llvm::append_range(shape, castDistributedType.getShape());
1086 castDistributedType =
1087 VectorType::get(shape, castDistributedType.getElementType());
1090 SmallVector<size_t> newRetIndices;
1092 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1095 Value newCast = vector::ShapeCastOp::create(
1096 rewriter, oldCastOp.getLoc(), castResultType,
1097 newWarpOp->getResult(newRetIndices[0]));
// create_mask distribution (visible interior): a yielded vector.create_mask
// whose bound operands are defined outside the region is rebuilt outside the
// warp; each mask bound is adjusted per lane via an affine expression
// (s1 - s0 * distShape[i], clamped elsewhere) using the delinearized lane ID.
// NOTE(review): fragment — interior lines missing; code unchanged.
1123 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1124 PatternRewriter &rewriter)
const override {
1125 OpOperand *yieldOperand =
1126 getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1134 if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1135 return warpOp.isDefinedOutsideOfRegion(value);
1139 Location loc = mask.getLoc();
1142 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1143 VectorType seqType = mask.getVectorType();
1144 ArrayRef<int64_t> seqShape = seqType.getShape();
1145 ArrayRef<int64_t> distShape = distType.getShape();
1150 SmallVector<Value> delinearizedIds;
1151 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1152 warpOp.getWarpSize(), warpOp.getLaneid(),
1155 mask,
"cannot delinearize lane ID for distribution");
1156 assert(!delinearizedIds.empty());
1164 SmallVector<Value> newOperands;
1165 for (
int i = 0, e = distShape.size(); i < e; ++i) {
1172 rewriter, loc, s1 - s0 * distShape[i],
1173 {delinearizedIds[i], mask.getOperand(i)});
1174 newOperands.push_back(maskDimIdx);
1178 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
// insert_strided_slice distribution (visible interior): requires a >=2-D
// distributed result, the distributed dim to fall within the source's trailing
// dims, and the distributed dim to be fully inserted (offset 0, full size).
// Source and dest are yielded with distributed types and the insert is redone
// outside the warp. NOTE(review): fragment — lines missing; code unchanged.
1213 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1214 PatternRewriter &rewriter)
const override {
1215 OpOperand *operand =
1216 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1222 auto distributedType =
1223 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1226 if (distributedType.getRank() < 2)
1228 insertOp,
"result vector type must be 2D or higher");
1231 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1232 int64_t destDistributedDim =
1234 assert(destDistributedDim != -1 &&
"could not find distributed dimension");
1236 VectorType srcType = insertOp.getSourceVectorType();
1237 VectorType destType = insertOp.getDestVectorType();
// Map the dest's distributed dim onto the (possibly lower-rank) source.
1242 int64_t sourceDistributedDim =
1243 destDistributedDim - (destType.getRank() - srcType.getRank());
1244 if (sourceDistributedDim < 0)
1247 "distributed dimension must be in the last k dims of dest vector");
1249 if (srcType.getDimSize(sourceDistributedDim) !=
1250 destType.getDimSize(destDistributedDim))
1252 insertOp,
"distributed dimension must be fully inserted");
1253 SmallVector<int64_t> newSourceDistShape(
1254 insertOp.getSourceVectorType().getShape());
1255 newSourceDistShape[sourceDistributedDim] =
1256 distributedType.getDimSize(destDistributedDim);
1258 VectorType::get(newSourceDistShape, distributedType.getElementType());
1259 VectorType newDestTy = distributedType;
1260 SmallVector<size_t> newRetIndices;
1261 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1262 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1263 {newSourceTy, newDestTy}, newRetIndices);
1265 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1266 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1269 Value newInsert = vector::InsertStridedSliceOp::create(
1270 rewriter, insertOp.getLoc(), distributedDest.
getType(),
1271 distributedSource, distributedDest, insertOp.getOffsets(),
1272 insertOp.getStrides());
// extract_strided_slice distribution (visible interior): requires a >=2-D
// distributed result; when the distributed dim is one of the extracted dims it
// must be fully extracted (offset 0, full size). The source is yielded with a
// distributed type, sizes are rescaled to the distributed dim size, and the
// extract is redone outside the warp.
// NOTE(review): fragment — interior lines missing; code unchanged.
1302 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1303 PatternRewriter &rewriter)
const override {
1304 OpOperand *operand =
1305 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1311 auto distributedType =
1312 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1315 if (distributedType.getRank() < 2)
1317 extractOp,
"result vector type must be 2D or higher");
1320 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1322 assert(distributedDim != -1 &&
"could not find distributed dimension");
1324 int64_t numOfExtractedDims =
1325 static_cast<int64_t
>(extractOp.getSizes().size());
1332 if (distributedDim < numOfExtractedDims) {
1333 int64_t distributedDimOffset =
1334 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1336 int64_t distributedDimSize =
1337 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1339 if (distributedDimOffset != 0 ||
1340 distributedDimSize != yieldedType.getDimSize(distributedDim))
1342 extractOp,
"distributed dimension must be fully extracted");
1344 SmallVector<int64_t> newDistributedShape(
1345 extractOp.getSourceVectorType().getShape());
1346 newDistributedShape[distributedDim] =
1347 distributedType.getDimSize(distributedDim);
1348 auto newDistributedType =
1349 VectorType::get(newDistributedShape, distributedType.getElementType());
1350 SmallVector<size_t> newRetIndices;
1351 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1352 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
// Rescale the extracted size of the distributed dim to the per-lane size.
1355 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1356 extractOp.getSizes(), [](Attribute attr) { return attr; });
1358 if (distributedDim <
static_cast<int64_t
>(distributedSizes.size()))
1360 distributedType.getDimSize(distributedDim));
1364 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1365 Value newExtract = vector::ExtractStridedSliceOp::create(
1366 rewriter, extractOp.getLoc(), distributedType, distributedVec,
1367 extractOp.getOffsets(),
1368 ArrayAttr::get(rewriter.
getContext(), distributedSizes),
1369 extractOp.getStrides());
// vector.extract distribution (visible interior): two cases. (1) rank<=1
// source where result type is undistributed: yield the full source and extract
// outside. (2) higher-rank source: distribute the trailing (non-extracted)
// dims to match the distributed result, yield the source with that type, and
// extract outside. NOTE(review): fragment — lines missing; code unchanged.
1380 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1381 PatternRewriter &rewriter)
const override {
1382 OpOperand *operand =
1383 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1388 VectorType extractSrcType = extractOp.getSourceVectorType();
1389 Location loc = extractOp.getLoc();
1392 if (extractSrcType.getRank() <= 1) {
1398 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1404 SmallVector<size_t> newRetIndices;
1405 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1406 rewriter, warpOp, {extractOp.getSource()},
1407 {extractOp.getSourceVectorType()}, newRetIndices);
1409 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1411 Value newExtract = vector::ExtractOp::create(
1412 rewriter, loc, distributedVec, extractOp.getMixedPosition());
1419 auto distributedType =
1420 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1421 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1423 assert(distributedDim != -1 &&
"could not find distributed dimension");
1424 (void)distributedDim;
// Distribute the source's trailing dims to the distributed result's dims.
1427 SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1428 for (
int i = 0; i < distributedType.getRank(); ++i)
1429 newDistributedShape[i + extractOp.getNumIndices()] =
1430 distributedType.getDimSize(i);
1431 auto newDistributedType =
1432 VectorType::get(newDistributedShape, distributedType.getElementType());
1433 SmallVector<size_t> newRetIndices;
1434 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1435 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1438 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1440 Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1441 extractOp.getMixedPosition());
// Scalar-extract distribution (visible interior): extracts a scalar from a
// 0-D/1-D f32 or i32 vector. The owning lane is computed from the position
// (pos / elementsPerLane) and the extracted value is broadcast to all lanes
// via the injected warpShuffleFromIdxFn callback.
// NOTE(review): garbled extraction — interior lines missing; code unchanged.
1451 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1452 PatternBenefit
b = 1)
1453 : WarpDistributionPattern(ctx,
b), warpShuffleFromIdxFn(std::move(fn)) {}
1454 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1455 PatternRewriter &rewriter)
const override {
1456 OpOperand *operand =
1457 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1462 VectorType extractSrcType = extractOp.getSourceVectorType();
1464 if (extractSrcType.getRank() > 1) {
1466 extractOp,
"only 0-D or 1-D source supported for now");
1470 if (!extractSrcType.getElementType().isF32() &&
1471 !extractSrcType.getElementType().isInteger(32))
1473 extractOp,
"only f32/i32 element types are supported");
1474 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1475 Type elType = extractSrcType.getElementType();
1476 VectorType distributedVecType;
1477 if (!is0dOrVec1Extract) {
1478 assert(extractSrcType.getRank() == 1 &&
1479 "expected that extract src rank is 0 or 1");
1480 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1482 int64_t elementsPerLane =
1483 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1484 distributedVecType = VectorType::get({elementsPerLane}, elType);
1486 distributedVecType = extractSrcType;
// Yield the source (distributed) plus any dynamic position operands.
1489 SmallVector<Value> additionalResults{extractOp.getSource()};
1490 SmallVector<Type> additionalResultTypes{distributedVecType};
1491 additionalResults.append(
1492 SmallVector<Value>(extractOp.getDynamicPosition()));
1493 additionalResultTypes.append(
1494 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1496 Location loc = extractOp.getLoc();
1497 SmallVector<size_t> newRetIndices;
1498 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1499 rewriter, warpOp, additionalResults, additionalResultTypes,
1502 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Trivial case: every lane holds the single element.
1506 if (is0dOrVec1Extract) {
1508 SmallVector<int64_t>
indices(extractSrcType.getRank(), 0);
1510 vector::ExtractOp::create(rewriter, loc, distributedVec,
indices);
1516 int64_t staticPos = extractOp.getStaticPosition()[0];
1517 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1518 ? (newWarpOp->getResult(newRetIndices[1]))
// Owning lane = pos ceildiv/div elementsPerLane; in-lane pos = pos mod
// elementsPerLane (exact affine exprs partially lost in extraction).
1522 int64_t elementsPerLane = distributedVecType.getShape()[0];
1526 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1529 elementsPerLane == 1
1532 sym0 % elementsPerLane, pos);
1534 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1537 Value shuffled = warpShuffleFromIdxFn(
1538 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1544 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
// Scalar-insert distribution (visible interior): inserts a scalar into a
// 0-D/1-D vector. When the dest type is undistributed, all lanes insert; when
// distributed, only the owning lane (pos div elementsPerLane) performs the
// insert inside an scf.if, other lanes yield the vector unchanged.
// NOTE(review): garbled extraction — interior lines missing; code unchanged.
1551 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1552 PatternRewriter &rewriter)
const override {
1553 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1558 VectorType vecType = insertOp.getDestVectorType();
1559 VectorType distrType =
1560 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1563 if (vecType.getRank() > 1) {
1565 insertOp,
"only 0-D or 1-D source supported for now");
// Yield dest (distributed), the stored value, and any dynamic position.
1569 SmallVector<Value> additionalResults{insertOp.getDest(),
1570 insertOp.getValueToStore()};
1571 SmallVector<Type> additionalResultTypes{
1572 distrType, insertOp.getValueToStore().getType()};
1573 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1574 additionalResultTypes.append(
1575 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1577 Location loc = insertOp.getLoc();
1578 SmallVector<size_t> newRetIndices;
1579 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1580 rewriter, warpOp, additionalResults, additionalResultTypes,
1583 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1584 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1588 if (vecType.getRank() != 0) {
1589 int64_t staticPos = insertOp.getStaticPosition()[0];
1590 pos = ShapedType::isDynamic(staticPos)
1591 ? (newWarpOp->getResult(newRetIndices[2]))
// Undistributed dest: every lane performs the same insert.
1596 if (vecType == distrType) {
1598 SmallVector<OpFoldResult>
indices;
1602 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1611 int64_t elementsPerLane = distrType.getShape()[0];
1615 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1618 rewriter, loc, sym0 % elementsPerLane, pos);
1619 Value isInsertingLane =
1620 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1621 newWarpOp.getLaneid(), insertingLane);
1624 rewriter, loc, isInsertingLane,
1626 [&](OpBuilder &builder, Location loc) {
1627 Value newInsert = vector::InsertOp::create(
1628 builder, loc, newSource, distributedVec, newPos);
1629 scf::YieldOp::create(builder, loc, newInsert);
1632 [&](OpBuilder &builder, Location loc) {
1633 scf::YieldOp::create(builder, loc, distributedVec);
1643 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1644 PatternRewriter &rewriter)
const override {
1645 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1650 Location loc = insertOp.getLoc();
1653 if (insertOp.getDestVectorType().getRank() <= 1) {
1659 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1662 SmallVector<size_t> newRetIndices;
1663 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1664 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1665 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1668 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1669 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1670 Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1672 insertOp.getMixedPosition());
1679 auto distrDestType =
1680 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1681 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1682 int64_t distrDestDim = -1;
1683 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1684 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1687 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1691 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1694 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1695 SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1702 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1703 if (distrSrcDim >= 0)
1704 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1706 VectorType::get(distrSrcShape, distrDestType.getElementType());
1709 SmallVector<size_t> newRetIndices;
1710 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1711 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1712 {distrSrcType, distrDestType}, newRetIndices);
1714 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1715 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1719 if (distrSrcDim >= 0) {
1721 newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1723 insertOp.getMixedPosition());
1726 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1727 SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1731 rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1732 Value isInsertingLane =
1733 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1734 newWarpOp.getLaneid(), insertingLane);
1736 newPos[distrDestDim] %= elementsPerLane;
1737 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1738 Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1739 distributedDest, newPos);
1740 scf::YieldOp::create(builder, loc, newInsert);
1742 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1743 scf::YieldOp::create(builder, loc, distributedDest);
1745 newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1747 nonInsertingBuilder)
1784 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
1786 PatternRewriter &rewriter)
const override {
1787 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1789 Operation *lastNode = warpOpYield->getPrevNode();
1790 auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1801 SmallVector<Value> nonIfYieldValues;
1802 SmallVector<unsigned> nonIfYieldIndices;
1803 llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1804 llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1805 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1808 nonIfYieldValues.push_back(yieldOperand.
get());
1809 nonIfYieldIndices.push_back(yieldOperandIdx);
1812 OpResult ifResult = cast<OpResult>(yieldOperand.
get());
1814 ifResultMapping[yieldOperandIdx] = ifResultIdx;
1817 if (!isa<VectorType>(ifResult.
getType()))
1819 VectorType distType =
1820 cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1821 ifResultDistTypes[ifResultIdx] = distType;
1826 auto [escapingValuesThen, escapingValueInputTypesThen,
1827 escapingValueDistTypesThen] =
1828 getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1830 auto [escapingValuesElse, escapingValueInputTypesElse,
1831 escapingValueDistTypesElse] =
1832 getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1834 if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1835 llvm::is_contained(escapingValueDistTypesElse, Type{}))
1843 SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1844 newWarpOpYieldValues.append(escapingValuesThen.begin(),
1845 escapingValuesThen.end());
1846 newWarpOpYieldValues.append(escapingValuesElse.begin(),
1847 escapingValuesElse.end());
1848 SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1849 newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1850 escapingValueDistTypesThen.end());
1851 newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1852 escapingValueDistTypesElse.end());
1854 for (
auto [idx, val] :
1855 llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1856 newWarpOpYieldValues.push_back(val);
1857 newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1861 SmallVector<size_t> newIndices;
1863 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
1865 SmallVector<Type> newIfOpDistResTypes;
1866 for (
auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1867 Type distType = cast<Value>(res).getType();
1868 if (
auto vecType = dyn_cast<VectorType>(distType)) {
1869 AffineMap map = distributionMapFn(cast<Value>(res));
1871 distType = ifResultDistTypes.count(i)
1872 ? ifResultDistTypes[i]
1873 : getDistributedType(vecType, map, warpOp.getWarpSize());
1875 newIfOpDistResTypes.push_back(distType);
1878 OpBuilder::InsertionGuard g(rewriter);
1880 auto newIfOp = scf::IfOp::create(
1881 rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1882 newWarpOp.getResult(newIndices[0]),
static_cast<bool>(ifOp.thenBlock()),
1883 static_cast<bool>(ifOp.elseBlock()));
1884 auto encloseRegionInWarpOp =
1886 llvm::SmallSetVector<Value, 32> &escapingValues,
1887 SmallVector<Type> &escapingValueInputTypes,
1888 size_t warpResRangeStart) {
1889 OpBuilder::InsertionGuard g(rewriter);
1893 llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1894 SmallVector<Value> innerWarpInputVals;
1895 SmallVector<Type> innerWarpInputTypes;
1896 for (
size_t i = 0; i < escapingValues.size();
1897 ++i, ++warpResRangeStart) {
1898 innerWarpInputVals.push_back(
1899 newWarpOp.getResult(newIndices[warpResRangeStart]));
1900 escapeValToBlockArgIndex[escapingValues[i]] =
1901 innerWarpInputTypes.size();
1902 innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1904 auto innerWarp = WarpExecuteOnLane0Op::create(
1905 rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1906 newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1907 innerWarpInputVals, innerWarpInputTypes);
1909 innerWarp.getWarpRegion().takeBody(*oldIfBranch->
getParent());
1910 innerWarp.getWarpRegion().addArguments(
1911 innerWarpInputTypes,
1912 SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1914 SmallVector<Value> yieldOperands;
1916 yieldOperands.push_back(operand);
1920 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1922 scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1926 innerWarp.walk([&](Operation *op) {
1928 auto it = escapeValToBlockArgIndex.find(operand.
get());
1929 if (it == escapeValToBlockArgIndex.end())
1931 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1934 mlir::vector::moveScalarUniformCode(innerWarp);
1936 encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1937 &newIfOp.getThenRegion().front(), escapingValuesThen,
1938 escapingValueInputTypesThen, 1);
1939 if (!ifOp.getElseRegion().empty())
1940 encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1941 &newIfOp.getElseRegion().front(),
1942 escapingValuesElse, escapingValueInputTypesElse,
1943 1 + escapingValuesThen.size());
1946 for (
auto [origIdx, newIdx] : ifResultMapping)
1948 newIfOp.getResult(newIdx), newIfOp);
1991 : WarpDistributionPattern(ctx,
b), distributionMapFn(std::move(fn)) {}
1992 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1993 PatternRewriter &rewriter)
const override {
1994 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1996 Operation *lastNode = warpOpYield->getPrevNode();
1997 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
2002 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2003 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2005 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2016 SmallVector<Value> nonForYieldedValues;
2017 SmallVector<unsigned> nonForResultIndices;
2018 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2019 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2020 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2023 nonForYieldedValues.push_back(yieldOperand.
get());
2027 OpResult forResult = cast<OpResult>(yieldOperand.
get());
2032 if (!isa<VectorType>(forResult.
getType()))
2034 VectorType distType = cast<VectorType>(
2036 forResultDistTypes[forResultNumber] = distType;
2044 SmallVector<Value> newWarpOpYieldValues;
2045 SmallVector<Type> newWarpOpDistTypes;
2046 newWarpOpYieldValues.insert(
2047 newWarpOpYieldValues.end(),
2048 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2049 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2050 {forOp.getLowerBound().getType(),
2051 forOp.getUpperBound().getType(),
2052 forOp.getStep().getType()});
2053 for (
auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2054 newWarpOpYieldValues.push_back(initArg);
2056 Type distType = initArg.getType();
2057 if (
auto vecType = dyn_cast<VectorType>(distType)) {
2061 AffineMap map = distributionMapFn(initArg);
2062 distType = forResultDistTypes.count(i)
2063 ? forResultDistTypes[i]
2064 : getDistributedType(vecType, map, warpOp.getWarpSize());
2066 newWarpOpDistTypes.push_back(distType);
2069 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2070 escapingValues.begin(), escapingValues.end());
2071 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2072 escapingValueDistTypes.begin(),
2073 escapingValueDistTypes.end());
2077 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2078 newWarpOpYieldValues.push_back(v);
2079 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2082 SmallVector<size_t> newIndices;
2083 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2084 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2088 const unsigned initArgsStartIdx = 3;
2089 const unsigned escapingValuesStartIdx =
2091 forOp.getInitArgs().size();
2093 SmallVector<Value> newForOpOperands;
2094 for (
size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2095 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2098 OpBuilder::InsertionGuard g(rewriter);
2100 auto newForOp = scf::ForOp::create(
2101 rewriter, forOp.getLoc(),
2102 newWarpOp.getResult(newIndices[0]),
2103 newWarpOp.getResult(newIndices[1]),
2104 newWarpOp.getResult(newIndices[2]), newForOpOperands,
2105 nullptr, forOp.getUnsignedCmp());
2111 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2112 newForOp.getRegionIterArgs().end());
2113 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2114 forOp.getResultTypes().end());
2118 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2119 for (
size_t i = escapingValuesStartIdx;
2120 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2121 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2122 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2123 innerWarpInputType.size();
2124 innerWarpInputType.push_back(
2125 escapingValueInputTypes[i - escapingValuesStartIdx]);
2128 auto innerWarp = WarpExecuteOnLane0Op::create(
2129 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2130 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2131 innerWarpInputType);
2134 SmallVector<Value> argMapping;
2135 argMapping.push_back(newForOp.getInductionVar());
2136 for (Value args : innerWarp.getBody()->getArguments())
2137 argMapping.push_back(args);
2139 argMapping.resize(forOp.getBody()->getNumArguments());
2140 SmallVector<Value> yieldOperands;
2141 for (Value operand : forOp.getBody()->getTerminator()->getOperands())
2142 yieldOperands.push_back(operand);
2144 rewriter.
eraseOp(forOp.getBody()->getTerminator());
2145 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2150 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2154 if (!innerWarp.getResults().empty())
2155 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2159 for (
auto [origIdx, newIdx] : forResultMapping)
2161 newForOp.getResult(newIdx), newForOp);
2164 newForOp.walk([&](Operation *op) {
2166 auto it = argIndexMapping.find(operand.
get());
2167 if (it == argIndexMapping.end())
2169 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
2174 mlir::vector::moveScalarUniformCode(innerWarp);
2202 WarpOpReduction(MLIRContext *context,
2203 DistributedReductionFn distributedReductionFn,
2204 PatternBenefit benefit = 1)
2205 : WarpDistributionPattern(context, benefit),
2206 distributedReductionFn(std::move(distributedReductionFn)) {}
2208 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2209 PatternRewriter &rewriter)
const override {
2210 OpOperand *yieldOperand =
2211 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2217 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2219 if (vectorType.getRank() != 1)
2221 warpOp,
"Only rank 1 reductions can be distributed.");
2223 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2225 warpOp,
"Reduction vector dimension must match was size.");
2226 if (!reductionOp.getType().isIntOrFloat())
2228 warpOp,
"Reduction distribution currently only supports floats and "
2231 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2234 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2235 SmallVector<Type> retTypes = {
2236 VectorType::get({numElements}, reductionOp.getType())};
2237 if (reductionOp.getAcc()) {
2238 yieldValues.push_back(reductionOp.getAcc());
2239 retTypes.push_back(reductionOp.getAcc().getType());
2241 SmallVector<size_t> newRetIndices;
2242 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2243 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2247 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2250 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2251 reductionOp.getKind(), newWarpOp.getWarpSize());
2252 if (reductionOp.getAcc()) {
2254 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2255 newWarpOp.getResult(newRetIndices[1]));
2262 DistributedReductionFn distributedReductionFn;
2273void mlir::vector::populateDistributeTransferWriteOpPatterns(
2276 patterns.add<WarpOpTransferWrite>(
patterns.getContext(), distributionMapFn,
2277 maxNumElementsToExtract, benefit);
2280void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2282 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
2286 .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2287 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
2288 WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
2289 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2291 patterns.add<WarpOpExtractScalar>(
patterns.getContext(), warpShuffleFromIdxFn,
2299void mlir::vector::populateDistributeReduction(
2301 const DistributedReductionFn &distributedReductionFn,
2303 patterns.add<WarpOpReduction>(
patterns.getContext(), distributedReductionFn,
2310 return llvm::all_of(op->
getOperands(), definedOutside) &&
2314void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2315 Block *body = warpOp.getBody();
2318 llvm::SmallSetVector<Operation *, 8> opsToMove;
2321 auto isDefinedOutsideOfBody = [&](
Value value) {
2323 return (definingOp && opsToMove.count(definingOp)) ||
2324 warpOp.isDefinedOutsideOfRegion(value);
2331 return isa<VectorType>(result.getType());
2333 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
2334 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.