#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
/// Currently the distribution map is implicit based on the vector shape: a
/// dimension is distributed iff its size differs between the sequential and
/// the distributed type.
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  return AffineMap::get(sequentialType.getRank(), 0, perm,
                        distributedType.getContext());
}
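// Illustrative example (values assumed for exposition): distributing a
// sequential vector<16x32xf32> as vector<16x1xf32> leaves dim 0 untouched and
// shrinks dim 1, so the implicit map is (d0, d1) -> (d1): only dim 1 is
// distributed across lanes.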
/// Helper to create the load / store operations that transit values across
/// the sequential / distributed boundaries when lowering a warp op to scf.if.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  /// Compute the buffer offset owned by this lane along the distributed
  /// dimension `index`: laneId * distributedSize.
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Store `val` (either the distributed or the sequential value) into
  /// `buffer`.
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector.transfer_write. Distributed values are
    // stored at per-lane offsets along the distributed dimensions.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Load a value of `type` (either the distributed or the sequential type)
  /// from `buffer`, mirroring buildStore.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};
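// Sketch of the resulting offset arithmetic (illustrative, assuming a warp of
// 32 lanes and a distributed type of vector<2xf32>): buildDistributedOffset
// materializes affine.apply affine_map<()[s0] -> (s0 * 2)>()[%laneId], so
// lane k transits elements [2k, 2k + 1] of the sequential vector through the
// buffer.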
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}
  // Lower WarpExecuteOnLane0Op to an scf.if op, where lane 0 executes the
  // sequential body and values transit through memory buffers created via
  // options.warpAllocationFn.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp->getLoc();

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create the scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: Store distributed args into buffers and reload the sequential
    // values inside the scf.if.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3: Synchronize after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4: Move the body of warpOp into the ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5: Store yielded values into buffers inside the scf.if and reload
    // the distributed values after it.
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    SmallVector<Value> replacements;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      rewriter.setInsertionPointAfter(ifOp);
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6: Synchronize after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7: Delete the terminator and add an empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Step 8: Replace the results of the warp op.
    rewriter.replaceOp(warpOp, replacements);
    return success();
  }
/// Return the distributed vector type based on the original type and the
/// distribution map, or a null type if distribution is impossible. The map is
/// expected to have one result per distributed dimension.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0)
        return VectorType();
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1)
    return VectorType();
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}
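// Illustrative examples (assumed values): for originalType vector<64x32xf32>,
// map (d0, d1) -> (d1) and warpSize 32, dim 1 divides evenly and the result is
// vector<64x1xf32>. For vector<64x4xf32> with the same map, dim 1 (size 4)
// divides the warp size instead, collapses to 1, and the leftover factor of 8
// has no further mapped dim to absorb it, so a null VectorType is returned.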
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}
  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported for now.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // TODO: Support 0-d vectors.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // Compute the distributed type for the new mask, if any.
    VectorType maskType;
    if (writeOp.getMask()) {
      // Masked writes with non-trivial permutation maps are not supported.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // Clone the write into a new WarpExecuteOnLane0Op to separate it from the
    // rest of the region.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector: the number of threads per dimension is the sequential size
    // divided by the distributed size.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // With a single map result, the delinearization op can be elided and
      // the lane id used directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);
    return success();
  }
  /// Extract a small (at most maxNumElementsToExtract elements) write into a
  /// separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only the write.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes are not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }
  /// Clone `writeOp` (assumed to be nested under `warpOp`) into a new warp
  /// execute op with the proper return type, and update it to write the
  /// result of the new warp op. The old write is deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }
  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};
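// Illustrative rewrite performed by WarpOpTransferWrite (shapes assumed, warp
// of 32 lanes):
//
//   %0 = gpu.warp_execute_on_lane_0(%id) {
//     ...
//     vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
//     gpu.yield
//   }
//
// becomes:
//
//   %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
//     ...
//     gpu.yield %v : vector<32xf32>
//   }
//   vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>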
  // WarpOpElementwise: sink an elementwise op out of the warp region by
  // yielding its operands instead of its result.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
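// Illustrative rewrite (shapes assumed, warp of 32 lanes):
//
//   %0 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
//     %2 = arith.addf %a, %b : vector<32xf32>
//     gpu.yield %2 : vector<32xf32>
//   }
//
// becomes:
//
//   %r:2 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>, vector<1xf32>) {
//     gpu.yield %a, %b : vector<32xf32>, vector<32xf32>
//   }
//   %0 = arith.addf %r#0, %r#1 : vector<1xf32>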
  // WarpOpConstant: re-materialize a yielded splat constant directly at the
  // distributed type, outside of the warp op.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    return success();
  }
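// Illustrative rewrite (shapes assumed): a yielded splat constant
//   arith.constant dense<2.0> : vector<32xf32>
// feeding a warp op result of type vector<1xf32> is replaced by
//   arith.constant dense<2.0> : vector<1xf32>
// created after the warp op.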
  // WarpOpTransferRead: distribute a vector.transfer_read feeding a warp op
  // result.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Yield the indices, the padding value and (if present) the mask from the
    // warp op so the new read can be built outside of it.
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // Masked reads with non-trivial permutation maps are not supported.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // The distributed padding value follows the indices; the mask, if any,
    // comes last.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());
    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
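// Illustrative rewrite (shapes assumed, warp of 32 lanes): a uniform read
// inside the region becomes a per-lane read outside it, with the lane id
// folded into the indexing:
//
//   %0 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
//     %2 = vector.transfer_read %A[%c0], %cst : memref<128xf32>, vector<32xf32>
//     gpu.yield %2 : vector<32xf32>
//   }
//
// becomes (roughly):
//
//   %0 = vector.transfer_read %A[%id], %cst : memref<128xf32>, vector<1xf32>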
  // WarpOpDeadResult: drop dead or duplicated warp op results.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicate by recording, for each result, the first position
    // at which its yielded value occurs, and skipping any result that is dead
    // or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warp op to a new one with the pruned results.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warp op by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
  // WarpOpForwardOperand: forward a value that is yielded unchanged (either
  // defined above the region or passed in as a warp op argument) directly to
  // the matching warp op result.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;
      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    return success();
  }
  // WarpOpBroadcast: sink a vector.broadcast out of the warp op by yielding
  // its (uniform) source instead.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // The source must be broadcastable to the distributed result type.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
  // WarpOpShapeCast: sink a vector.shape_cast out of the warp op by yielding
  // its source at the distributed type.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // The distributed type may have a smaller rank than the original type.
    // Prepend size-one dimensions to make the ranks match.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
  // WarpOpCreateMask: distribute a vector.create_mask whose operands are all
  // defined outside the warp region.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any value needed to compute the new mask operands is
    // defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Each lane computes the distance from its first owned index
      // (lane_id[i] * dist_sizes[i]) to the sequential mask bound;
      // vector.create_mask implicitly clamps its operands to the valid range.
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    return success();
  }
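// Sketch of the mask-size arithmetic (illustrative, assumed shapes): for a
// sequential mask vector.create_mask %ub : vector<32xi1> distributed as
// vector<1xi1>, lane i computes %ub - i * 1; create_mask's implicit clamping
// then sets lane i's single mask element iff i < %ub.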
  // WarpOpExtract: distribute a vector.extract with a rank >= 2 source.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // Rank-0 and rank-1 sources are handled by WarpOpExtractScalar.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast: simply move the
      // extract out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield the source vector from the warp op at the distributed type.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from the distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only 0-D and 1-D sources are supported for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield the source vector and the dynamic position (if any) from the
    // warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>{extractOp.getDynamicPosition()});
    additionalResultTypes.append(
        SmallVector<Type>{extractOp.getDynamicPosition().getTypes()});

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0-d case: the new warp op broadcasts the source vector to all lanes,
    // and every lane extracts the scalar.
    if (is0dOrVec1Extract) {
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      Value newExtract =
          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    // 1-d case: distribute the source vector; one lane extracts and shuffles
    // the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // Lane that owns the element: pos / elementsPerLane.
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Position within that lane: pos % elementsPerLane.
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
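// Sketch of the dataflow implemented above (illustrative, warp of 32 lanes,
// source vector<64xf32>, so elementsPerLane == 2): the lane owning position
// `pos` extracts its local element at pos % elementsPerLane, and
// warpShuffleFromIdxFn then broadcasts that scalar to every lane.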
  // WarpOpExtractElement: lower vector.extractelement to vector.extract so
  // the extract patterns above can handle it.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = extractOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(extractOp);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getVector(), indices);
    return success();
  }
  // WarpOpInsertScalar: distribute a vector.insert with a 0-D or 1-D
  // destination.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Only 0-D and 1-D destinations are supported for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D source supported for now");
    }

    // Yield the destination vector, the source scalar and the dynamic
    // position (if any) from the warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getValueToStore()};
    SmallVector<Type> additionalResultTypes{
        distrType, insertOp.getValueToStore().getType()};
    additionalResults.append(
        SmallVector<Value>{insertOp.getDynamicPosition()});
    additionalResultTypes.append(
        SmallVector<Type>{insertOp.getDynamicPosition().getTypes()});

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-d vectors.
    if (vecType == distrType) {
      // No distribution, this is a broadcast: move the insert out directly.
      SmallVector<OpFoldResult> indices;
      if (pos)
        indices.push_back(pos);
      Value newInsert = rewriter.create<vector::InsertOp>(
          loc, newSource, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution: only one lane inserts.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // Lane that does the insertion: pos / elementsPerLane.
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Insert position within that lane: pos % elementsPerLane.
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertOp>(
                      loc, newSource, distributedVec, newPos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
  // WarpOpInsert: distribute a vector.insert with a rank >= 2 destination.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // Rank-0 and rank-1 destinations are handled by WarpOpInsertScalar.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast: simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). Every lane inserts a smaller
    //         vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. Only
    //         one lane inserts the source vector<96xf32>; the other lanes do
    //         nothing.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield the source and dest vectors from the warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts its piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<int64_t> newPos =
          llvm::to_vector(insertOp.getStaticPosition());
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
  // WarpOpInsertElement: lower vector.insertelement to vector.insert so the
  // insert patterns above can handle it.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
    if (!operand)
      return failure();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = insertOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(insertOp);
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        insertOp, insertOp.getSource(), insertOp.getDest(), indices);
    return success();
  }
  // WarpOpScfForOp: sink an scf.for (the last op in the region) out of the
  // warp op, recreating a smaller warp op inside the new loop body.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up a forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();

    // Collect values defined in the warp op but used inside the forOp. Those
    // need to be returned by the original warp op and passed to the new one.
    llvm::SmallSetVector<Value, 32> escapingValues;
    SmallVector<Type> inputTypes;
    SmallVector<Type> distTypes;
    auto collectEscapingValues = [&](Value value) {
      if (!escapingValues.insert(value))
        return;
      Type distType = value.getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(value);
        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
      }
      inputTypes.push_back(value.getType());
      distTypes.push_back(distType);
    };

    mlir::visitUsedValuesDefinedAbove(
        forOp.getBodyRegion(), [&](OpOperand *operand) {
          Operation *parent = operand->get().getParentRegion()->getParentOp();
          if (warpOp->isAncestor(parent)) {
            collectEscapingValues(operand->get());
          }
        });

    // Any forOp result not already yielded by the warp op region is also
    // considered escaping.
    for (OpResult forResult : forOp.getResults()) {
      if (llvm::is_contained(yield->getOperands(), forResult))
        continue;
      collectEscapingValues(forResult);
    }

    if (llvm::is_contained(distTypes, Type{}))
      return failure();

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
        newRetIndices);
    yield = cast<gpu::YieldOp>(
        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = cast<OpResult>(yieldOperand.get());
      newOperands.push_back(
          newWarpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a new for op outside the region, with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPointToStart(newForOp.getBody());

    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
                                 newForOp.getRegionIterArgs().end());
    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
                                    forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
      auto newWarpResult = newWarpOp.getResult(retIdx);
      // Unused forOp results yielded by the warp op region are already
      // included in the new forOp.
      if (llvm::is_contained(newOperands, newWarpResult))
        continue;
      warpInput.push_back(newWarpResult);
      argIndexMapping[escapingValues[i]] = warpInputType.size();
      warpInputType.push_back(inputTypes[i]);
    }
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
        newWarpOp.getWarpSize(), warpInput, warpInputType);

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);

    // Replace the warp op results coming from the original forOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
                                  newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
    }
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now-uniform code from the inner warp op.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }
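// Illustrative rewrite (shapes assumed, warp of 32 lanes):
//
//   %w = gpu.warp_execute_on_lane_0(%id) -> (vector<4xf32>) {
//     ...
//     %v1 = scf.for ... iter_args(%arg = %v) -> (vector<128xf32>) { ... }
//     gpu.yield %v1 : vector<128xf32>
//   }
//
// becomes (roughly):
//
//   %w0 = gpu.warp_execute_on_lane_0(%id) -> (vector<4xf32>) { ... }
//   %w1 = scf.for ... iter_args(%varg = %w0) -> (vector<4xf32>) {
//     %iw = gpu.warp_execute_on_lane_0(%id) args(%varg : vector<4xf32>)
//         -> (vector<4xf32>) { ... }
//     scf.yield %iw : vector<4xf32>
//   }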
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank-1 vectors are supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only vectors whose size is a multiple of the warp size are supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Yield the vector to be reduced (and the accumulator, if present) from
    // the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain the data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

  DistributedReductionFn distributedReductionFn;
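// Sketch of the decomposition (illustrative, warp of 32 lanes reducing
// vector<64xf32>): each lane yields its private vector<2xf32> slice, the
// plugged-in distributedReductionFn combines the per-lane partial values
// across lanes (e.g. via gpu.shuffle), and any accumulator is folded in
// afterwards with makeArithReduction.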
void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
               WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
               WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
               WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
      patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(),
                                    warpShuffleFromIdxFn, benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}
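// Hypothetical usage sketch (not part of this file; `funcOp` and the trivial
// distribution map lambda are assumptions): populating and applying the
// transfer_write distribution patterns with a greedy driver.
//
//   RewritePatternSet patterns(funcOp.getContext());
//   auto mapFn = [](Value v) {
//     auto vecType = cast<VectorType>(v.getType());
//     return AffineMap::getMinorIdentityMap(vecType.getRank(), 1,
//                                           v.getContext());
//   };
//   vector::populateDistributeTransferWriteOpPatterns(
//       patterns, mapFn, /*maxNumElementsToExtract=*/1);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));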
/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}