#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"

/// Currently the distribution map is implicit based on the vector shape: the
/// dimensions where the sequential and distributed shapes differ are the
/// distributed dimensions.
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  return AffineMap::get(sequentialType.getRank(), 0, perm,
                        distributedType.getContext());
}
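// Worked example (illustrative, shapes assumed): distributing
// vector<16x64xf32> as vector<16x2xf32> leaves dim 0 untouched and splits
// dim 1 across lanes, so the implicit distribution map is (d0, d1) -> (d1).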
/// Helper to build the loads and stores that move values across the
/// sequential/distributed boundary through a shared buffer.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector.transfer_write. Distributed values are
    // written at a per-lane offset along the distributed dimensions.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};
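// Illustrative (shapes assumed): for a sequential vector<128xf32> distributed
// as vector<4xf32> over 32 lanes, buildDistributedOffset produces
// affine_apply affine_map<()[s0] -> (s0 * 4)>()[%laneid], so each lane
// touches a disjoint 4-element slice of the shared buffer.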
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}
/// Lower a WarpExecuteOnLane0Op to an scf.if guarded on lane 0, bridging
/// region arguments and results through shared buffers.
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *ctx,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Step 1: create the scf.if op guarded on `laneid == 0`.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: bridge block arguments: store the distributed value before the
    // scf.if, load the sequential value inside it.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 3: move the body of warpOp into the scf.if.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 4: bridge yielded values: store the sequential value inside the
    // scf.if, load the distributed value after it.
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    SmallVector<Value> replacements;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);
      rewriter.setInsertionPointAfter(ifOp);
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 5: replace the terminator with an empty scf.yield and replace the
    // warpOp results.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);
    rewriter.replaceOp(warpOp, replacements);
    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};
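// Rough shape of the resulting IR (illustrative sketch; buffer setup and
// exact types assumed):
//
//   %buf = ... via options.warpAllocationFn ...
//   vector.transfer_write %distributedArg, %buf[%laneid * k]  // every lane
//   <warp sync>
//   scf.if %laneid == %c0 {
//     %seq = vector.transfer_read %buf[%c0] ...  // full sequential value
//     ... original warp-op body ...
//     vector.transfer_write %seqResult, %buf[%c0] ...
//   }
//   <warp sync>
//   %result = vector.transfer_read %buf[%laneid * k] ...  // distributed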
/// Return the distributed vector type for `originalType` distributed over
/// `warpSize` lanes along the dimensions selected by `map`, or a null type if
/// distribution is not possible.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      // A dim smaller than the warp size collapses to 1 and carries the
      // remaining lanes to the next mapped dim.
      if (warpSize % targetShape[position] != 0)
        return VectorType();
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1)
    return VectorType();
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}
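// Worked example (illustrative): with map (d0, d1) -> (d1) and warpSize 32,
// vector<16x64xf32> distributes to vector<16x2xf32>. If the mapped dim does
// not divide evenly into the warp size (and vice versa), a null VectorType
// signals that distribution is not possible.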
/// Distribute transfer_write ops based on the distribution map returned by
/// the caller-provided DistributionMapFn.
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. A 0-D write cannot be distributed; it is handled by tryExtractOp.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: distribution of masked writes with non-trivial permutation maps
      // is not supported.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest of the region.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector: the thread count per dim is seqSize / distSize.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // With a single map result, the delinearization op can be elided and
      // the lane id used directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);
    return success();
  }
  /// Extract a small (at most `maxNumElementsToExtract` elements) write into
  /// a separate warp op that contains only the write.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only the write.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();
    return failure();
  }
private:
  /// Clone `writeOp`, assumed to be nested under `warpOp`, into a new warp
  /// execute op with the proper return type; the clone writes the result of
  /// the new warp op and the old `writeOp` is deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};
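// Illustrative rewrite (IR sketch, shapes assumed):
//
//   gpu.warp_execute_on_lane_0(%laneid)[32] {
//     %v = "some_def"() : () -> vector<32xf32>
//     vector.transfer_write %v, %dest[%c0] : vector<32xf32>, memref<...>
//   }
//
// distributes to
//
//   %v1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %v = "some_def"() : () -> vector<32xf32>
//     gpu.yield %v : vector<32xf32>
//   }
//   vector.transfer_write %v1, %dest[%laneid] : vector<1xf32>, memref<...>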
/// Distribute elementwise ops: yield the operands out of the region and
/// re-create the op on the distributed values.
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors,
        // distributed to the same shape.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};
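// Illustrative (IR sketch): a yielded arith.addf is pulled out of the region
// and re-created on the distributed operands:
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %3 = arith.addf %1, %2 : vector<32xf32>
//     gpu.yield %3 : vector<32xf32>
//   }
// becomes
//   %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
//       -> (vector<1xf32>, vector<1xf32>) {
//     gpu.yield %1, %2 : vector<32xf32>, vector<32xf32>
//   }
//   %s = arith.addf %r#0, %r#1 : vector<1xf32>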
/// Distribute a yielded splat constant by re-materializing it at the
/// distributed result type.
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    return success();
  }
};
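// Illustrative: a yielded arith.constant dense<2.0> : vector<32xf32> is
// replaced, outside the warp op, by arith.constant dense<2.0> : vector<1xf32>
// (the distributed result type); only splat constants can be distributed this
// way, since every lane must hold the same value.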
/// Distribute a yielded transfer_read by re-creating it outside the region
/// with per-lane indices.
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Yield the indices and padding (and possibly the mask) out of the warp
    // region so the new read can use them.
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: distribution of masked reads with non-trivial permutation maps
      // is not supported.
      if (!read.getPermutationMap().isMinorIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Padding was yielded right after the indices; the mask, if any, is last.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());
    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};
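// Illustrative rewrite (IR sketch): a read yielded from the warp region
//
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %2 = vector.transfer_read %src[%c0], %pad : memref<...>, vector<32xf32>
//     gpu.yield %2 : vector<32xf32>
//   }
//
// becomes a per-lane read outside the region:
//
//   %r = vector.transfer_read %src[%laneid], %pad : memref<...>, vector<1xf32>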
/// Remove dead and duplicated warp op results.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // A value may be yielded multiple times; record the first position at
    // which each yielded value appears, map every result to that position,
    // and skip results that are unused or whose value was already seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};
/// Forward a yielded value directly to the matching warp op result when no
/// computation inside the region is needed (the value comes from above or is
/// a forwarded warp operand).
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(
              operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      // Or forward a block argument back to the matching warp operand.
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    return success();
  }
};
/// Distribute a yielded broadcast by re-creating it outside the region on the
/// distributed result type.
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that every lane can reconstruct its own broadcast, i.e. that the
    // broadcast we want to build is valid.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};
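// Illustrative: a uniform scalar broadcast inside the region,
//   %b = vector.broadcast %src : f32 to vector<32xf32>
// is rebuilt outside on the distributed type:
//   %b = vector.broadcast %src : f32 to vector<1xf32>
// which is valid because every lane can recompute the broadcast locally.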
/// Distribute a yielded shape_cast by casting the distributed source outside
/// the region.
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // The distributed type may have a smaller rank than the original type;
    // prepend size-one dimensions to make the ranks match.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};
/// Distribute a yielded create_mask by computing per-lane mask bounds.
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();
    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    rewriter.startOpModification(warpOp);
    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Compute `mask_size[i] - lane_id[i] * dist_size[i]`, i.e. each lane's
      // local bound; create_mask implicitly clamps its operands into
      // [0, mask_vector_size[i]].
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
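// Illustrative: for vector.create_mask %sz : vector<32xi1> distributed as
// vector<1xi1> over 32 lanes, lane i computes its local bound as %sz - i * 1;
// create_mask clamps that into [0, 1], so lanes with i >= %sz get an
// all-false mask and the others an all-true one.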
/// Distribute a yielded vector.extract with rank >= 2 source.
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // Rank <= 1 sources are handled by the scalar-extract pattern below.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // If the yielded and distributed types match there is no distribution;
    // simply move the extract out of the warp region.
    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from the distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield the source vector from the warp op at a distributed type whose
    // trailing dims mirror the distributed result dims.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from the distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
/// Distribute a scalar vector.extract from a 0-D or 1-D source: one lane
/// extracts the value and shuffles it to all other lanes.
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // TODO: supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield the source vector and the dynamic position (if any) from the
    // warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0-D or vector<1> extract: every lane already has the value.
    if (is0dOrVec1Extract) {
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      Value newExtract =
          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));

    // 1-D extract: one lane extracts and shuffles the value to all others.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), {pos});
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, {pos});
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
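// Worked example (illustrative): extracting element 6 from vector<64xf32>
// distributed over 32 lanes (elementsPerLane = 2): the owning lane is
// 6 / 2 = 3 and the local position is 6 % 2 = 0; lane 3 extracts its local
// element and warpShuffleFromIdxFn broadcasts it to every lane.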
/// Canonicalize vector.extractelement to vector.extract so the patterns
/// above can handle it.
struct WarpOpExtractElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = extractOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(extractOp);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getVector(), indices);
    return success();
  }
};
/// Distribute a scalar vector.insert into a 0-D or 1-D destination: only the
/// owning lane performs the insert.
struct WarpOpInsertScalar : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Only 0-D or 1-D destinations are supported for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D source supported for now");
    }

    // Yield the destination vector, source scalar and dynamic position (if
    // any) from the warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-D vectors: no distribution, simply
    // move the insert out of the warp region.
    if (vecType == distrType) {
      SmallVector<OpFoldResult> indices;
      if (pos) {
        indices.push_back(pos);
      }
      Value newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
                                                          distributedVec,
                                                          indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution: only one lane may insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting thread: pos / elementsPerLane
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), {pos});
    // Insert position: pos % elementsPerLane
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, {pos});
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertOp>(
                      loc, newSource, distributedVec, newPos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
/// Distribute a yielded vector.insert with rank >= 2 destination.
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // Rank <= 1 destinations are handled by WarpOpInsertScalar.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // If the yielded and distributed types match there is no distribution;
    // simply move the insert out of the warp region.
    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
          {insertOp.getSourceType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type. If the distributed dim
    // falls inside the source, every lane inserts a slice; otherwise only the
    // inserting lane uses its (undistributed) source.
    VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield the source and dest vectors from the warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts its own slice.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<int64_t> newPos =
          llvm::to_vector(insertOp.getStaticPosition());
      // tid of the inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/
                                         nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
/// Canonicalize vector.insertelement to vector.insert so the patterns above
/// can handle it.
struct WarpOpInsertElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!operand)
      return failure();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = insertOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(insertOp);
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
    return success();
  }
};
/// Sink an scf.for region out of the warp op: the loop is hoisted and its
/// body is wrapped in a new inner warp op.
struct WarpOpScfForOp : public WarpDistributionPattern {
  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up the ForOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();

    // Collect values defined in the warp op but used inside the forOp. Those
    // need to be returned by the warp op and passed to the new inner warp op.
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> inputTypes;
    SmallVector<Type> distTypes;
    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            if (!escapingValues.insert(operand->get()))
              return;
            Type distType = operand->get().getType();
            if (auto vecType = dyn_cast<VectorType>(distType)) {
              AffineMap map = distributionMapFn(operand->get());
              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
            }
            inputTypes.push_back(operand->get().getType());
            distTypes.push_back(distType);
          }
        });

    if (llvm::is_contained(distTypes, Type{}))
      return failure();

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<gpu::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Collect the forOp results that are yielded and redirect the yield to
    // the loop's init args.
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region, with an inner warp op wrapping
    // the original loop body.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPointToStart(newForOp.getBody());

    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      warpInput.push_back(newWarpOp.getResult(retIdx));
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);

    // Replace the warpOp results coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now-uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};
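// Illustrative rewrite (IR sketch): an scf.for that is the last op in the
// warp region is hoisted out and its body re-wrapped, so the loop control
// runs uniformly on every lane:
//
//   gpu.warp_execute_on_lane_0(%laneid)[32] {
//     scf.for %i = %lb to %ub step %s { ... }
//   }
// becomes
//   scf.for %i = %lb to %ub step %s {
//     gpu.warp_execute_on_lane_0(%laneid)[32] { ... }
//   }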
/// Distribute a vector.reduction: every lane reduces its own slice, then the
/// caller-provided function combines the lane-local results across the warp.
struct WarpOpReduction : public WarpDistributionPattern {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors are supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vectors whose size is a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Return the vector that will be reduced from the WarpExecuteOnLane0Op,
    // with the accumulator if present.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain the data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};
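// Illustrative: reducing vector<64xf32> over 32 lanes gives numElements = 2;
// each lane receives vector<2xf32>, and distributedReductionFn (for example a
// shuffle-based butterfly reduction supplied by the caller) combines the 32
// lane-local values into the final scalar, folding in the accumulator last.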
void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
      patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(),
                                    warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}
/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
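// Usage sketch (illustrative, not part of this file): a client pass might
// register these patterns roughly as follows. The distribution map below
// (distribute the trailing dimension) is an assumption for the example.
//
//   RewritePatternSet patterns(ctx);
//   DistributionMapFn mapFn = [](Value v) {
//     auto vecType = cast<VectorType>(v.getType());
//     // Distribute along the trailing dimension only (hypothetical choice).
//     return AffineMap::get(vecType.getRank(), 0,
//                           getAffineDimExpr(vecType.getRank() - 1,
//                                            v.getContext()),
//                           v.getContext());
//   };
//   vector::populateDistributeTransferWriteOpPatterns(
//       patterns, mapFn, /*maxNumElementsToExtract=*/1);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));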