18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/FormatVariadic.h"
38 VectorType distributedType) {
44 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
45 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
49 distributedType.getContext());
60 struct DistributedLoadStoreHelper {
61 DistributedLoadStoreHelper(
Value sequentialVal,
Value distributedVal,
63 : sequentialVal(sequentialVal), distributedVal(distributedVal),
64 laneId(laneId), zero(zero) {
65 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.
getType());
66 distributedVectorType = dyn_cast<VectorType>(distributedVal.
getType());
67 if (sequentialVectorType && distributedVectorType)
73 int64_t distributedSize = distributedVectorType.getDimSize(index);
75 return b.
createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
88 assert((val == distributedVal || val == sequentialVal) &&
89 "Must store either the preregistered distributed or the "
90 "preregistered sequential value.");
92 if (!isa<VectorType>(val.
getType()))
93 return b.
create<memref::StoreOp>(loc, val, buffer, zero);
97 int64_t rank = sequentialVectorType.getRank();
99 if (val == distributedVal) {
100 for (
auto dimExpr : distributionMap.getResults()) {
101 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
102 indices[index] = buildDistributedOffset(b, loc, index);
106 return b.
create<vector::TransferWriteOp>(
107 loc, val, buffer, indices,
134 if (!isa<VectorType>(type))
135 return b.
create<memref::LoadOp>(loc, buffer, zero);
140 assert((type == distributedVectorType || type == sequentialVectorType) &&
141 "Must store either the preregistered distributed or the "
142 "preregistered sequential type.");
144 if (type == distributedVectorType) {
145 for (
auto dimExpr : distributionMap.getResults()) {
146 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
147 indices[index] = buildDistributedOffset(b, loc, index);
151 return b.
create<vector::TransferReadOp>(
152 loc, cast<VectorType>(type), buffer, indices,
156 Value sequentialVal, distributedVal, laneId, zero;
157 VectorType sequentialVectorType, distributedVectorType;
170 auto newWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
171 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
172 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
174 Region &opBody = warpOp.getBodyRegion();
175 Region &newOpBody = newWarpOp.getBodyRegion();
179 assert(newWarpOp.getWarpRegion().hasOneBlock() &&
180 "expected WarpOp with single block");
183 cast<vector::YieldOp>(newOpBody.
getBlocks().begin()->getTerminator());
186 yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
197 warpOp.getResultTypes().end());
198 auto yield = cast<vector::YieldOp>(
199 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
200 llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
201 yield.getOperands().end());
202 for (
auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
203 if (yieldValues.insert(std::get<0>(newRet))) {
204 types.push_back(std::get<1>(newRet));
205 indices.push_back(yieldValues.size() - 1);
208 for (
auto [idx, yieldOperand] :
210 if (yieldOperand == std::get<0>(newRet)) {
211 indices.push_back(idx);
217 yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
219 rewriter, warpOp, yieldValues.getArrayRef(), types);
221 newWarpOp.getResults().take_front(warpOp.getNumResults()));
228 return llvm::all_of(op->
getOperands(), definedOutside) &&
235 const std::function<
bool(
Operation *)> &fn) {
236 auto yield = cast<vector::YieldOp>(
237 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
238 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
239 Value yieldValues = yieldOperand.get();
241 if (definedOp && fn(definedOp)) {
242 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
243 return &yieldOperand;
257 return rewriter.
create(res);
290 struct WarpOpToScfIfPattern :
public OpRewritePattern<WarpExecuteOnLane0Op> {
297 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
299 assert(warpOp.getBodyRegion().hasOneBlock() &&
300 "expected WarpOp with single block");
301 Block *warpOpBody = &warpOp.getBodyRegion().
front();
309 Value c0 = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
311 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
312 auto ifOp = rewriter.
create<scf::IfOp>(loc, isLane0,
314 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
321 Value distributedVal = it.value();
322 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
323 warpOp.getLaneid(), c0);
327 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
330 helper.buildStore(rewriter, loc, distributedVal, buffer);
333 bbArgReplacements.push_back(
334 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
338 if (!warpOp.getArgs().empty()) {
340 options.warpSyncronizationFn(loc, rewriter, warpOp);
344 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
351 auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
352 Location yieldLoc = yieldOp.getLoc();
354 Value sequentialVal = it.value();
355 Value distributedVal = warpOp->getResult(it.index());
356 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
357 warpOp.getLaneid(), c0);
361 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
367 helper.buildStore(rewriter, loc, sequentialVal, buffer);
378 replacements.push_back(
379 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
383 if (!yieldOp.getOperands().empty()) {
385 options.warpSyncronizationFn(loc, rewriter, warpOp);
391 rewriter.
create<scf::YieldOp>(yieldLoc);
394 rewriter.
replaceOp(warpOp, replacements);
407 static vector::TransferWriteOp cloneWriteOp(
RewriterBase &rewriter,
408 WarpExecuteOnLane0Op warpOp,
409 vector::TransferWriteOp writeOp,
410 VectorType targetType,
411 VectorType maybeMaskType) {
412 assert(writeOp->getParentOp() == warpOp &&
413 "write must be nested immediately under warp");
416 WarpExecuteOnLane0Op newWarpOp;
419 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
420 TypeRange{targetType, maybeMaskType}, newRetIndices);
423 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
428 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
430 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
432 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
444 static VectorType getDistributedType(VectorType originalType,
AffineMap map,
449 if (targetShape[position] % warpSize != 0) {
450 if (warpSize % targetShape[position] != 0) {
453 warpSize /= targetShape[position];
454 targetShape[position] = 1;
457 targetShape[position] = targetShape[position] / warpSize;
464 VectorType targetType =
488 struct WarpOpTransferWrite :
public OpRewritePattern<WarpExecuteOnLane0Op> {
492 distributionMapFn(std::move(fn)),
493 maxNumElementsToExtract(maxNumElementsToExtract) {}
498 vector::TransferWriteOp writeOp,
499 WarpExecuteOnLane0Op warpOp)
const {
500 VectorType writtenVectorType = writeOp.getVectorType();
504 if (writtenVectorType.getRank() == 0)
508 AffineMap map = distributionMapFn(writeOp.getVector());
509 VectorType targetType =
510 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
516 if (writeOp.getMask()) {
523 if (!writeOp.getPermutationMap().isMinorIdentity())
526 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
531 vector::TransferWriteOp newWriteOp =
532 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
536 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
543 for (
auto [seqSize, distSize] :
544 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
545 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
546 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
550 delinearized = rewriter
551 .
create<mlir::affine::AffineDelinearizeIndexOp>(
552 newWarpOp.getLoc(), newWarpOp.getLaneid(),
558 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
564 newWriteOp.getIndices().end());
567 bindDims(newWarpOp.getContext(), d0, d1);
568 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
571 unsigned indexPos = indexExpr.getPosition();
572 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
573 Value laneId = delinearized[vectorPos];
577 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
579 newWriteOp.getIndicesMutable().assign(indices);
586 vector::TransferWriteOp writeOp,
587 WarpExecuteOnLane0Op warpOp)
const {
589 VectorType vecType = writeOp.getVectorType();
591 if (vecType.getNumElements() > maxNumElementsToExtract) {
595 "writes more elements ({0}) than allowed to extract ({1})",
596 vecType.getNumElements(), maxNumElementsToExtract));
600 if (llvm::all_of(warpOp.getOps(),
601 llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
608 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
612 auto secondWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
613 loc,
TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
614 Block &body = secondWarpOp.getBodyRegion().
front();
617 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
618 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
620 rewriter.
create<vector::YieldOp>(newWarpOp.getLoc());
624 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
626 auto yield = cast<vector::YieldOp>(
627 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
628 Operation *lastNode = yield->getPrevNode();
629 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
633 Value maybeMask = writeOp.getMask();
634 if (!llvm::all_of(writeOp->getOperands(), [&](
Value value) {
635 return writeOp.getVector() == value ||
636 (maybeMask && maybeMask == value) ||
637 warpOp.isDefinedOutsideOfRegion(value);
641 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
645 if (writeOp.getMask())
648 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
656 unsigned maxNumElementsToExtract = 1;
679 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
689 Value distributedVal = warpOp.getResult(operandIndex);
695 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
697 auto operandType = cast<VectorType>(operand.get().getType());
701 auto operandType = operand.get().getType();
702 assert(!isa<VectorType>(operandType) &&
703 "unexpected yield of vector from op with scalar result type");
704 targetType = operandType;
706 retTypes.push_back(targetType);
707 yieldValues.push_back(operand.get());
711 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
715 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
716 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
721 rewriter, loc, elementWise, newOperands,
722 {newWarpOp.getResult(operandIndex).
getType()});
745 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
752 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
761 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
764 Value distConstant = rewriter.
create<arith::ConstantOp>(loc, newAttr);
783 if (originalShape == distributedShape) {
784 delinearizedIds.clear();
789 for (
auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
790 if (large % small != 0)
792 sizes.push_back(large / small);
794 if (std::accumulate(sizes.begin(), sizes.end(), 1,
795 std::multiplies<int64_t>()) != warpSize)
801 int64_t usedThreads = 1;
803 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
804 delinearizedIds.assign(sizes.size(), zero);
806 for (
int i = sizes.size() - 1; i >= 0; --i) {
807 usedThreads *= sizes[i];
808 if (usedThreads == warpSize) {
811 delinearizedIds[i] = laneId;
817 builder, loc, s0.
floorDiv(usedThreads), {laneId});
842 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
849 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
853 warpOp,
"warp result is not a vector.transfer_read op");
857 if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
859 read,
"source must be defined outside of the region");
862 Value distributedVal = warpOp.getResult(operandIndex);
865 read.getIndices().end());
866 auto sequentialType = cast<VectorType>(read.getResult().getType());
867 auto distributedType = cast<VectorType>(distributedVal.
getType());
874 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
875 distributedType.getShape(), warpOp.getWarpSize(),
876 warpOp.getLaneid(), delinearizedIds)) {
878 read,
"cannot delinearize lane ID for distribution");
880 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
887 additionalResults.push_back(read.getPadding());
888 additionalResultTypes.push_back(read.getPadding().getType());
890 bool hasMask =
false;
891 if (read.getMask()) {
901 read,
"non-trivial permutation maps not supported");
902 VectorType maskType =
903 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
904 additionalResults.push_back(read.getMask());
905 additionalResultTypes.push_back(maskType);
910 rewriter, warpOp, additionalResults, additionalResultTypes,
912 distributedVal = newWarpOp.getResult(operandIndex);
916 for (int64_t i = 0, e = indices.size(); i < e; ++i)
917 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
922 bindDims(read.getContext(), d0, d1);
923 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
926 unsigned indexPos = indexExpr.getPosition();
927 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
928 int64_t scale = distributedType.getDimSize(vectorPos);
930 rewriter, read.getLoc(), d0 + scale * d1,
931 {newIndices[indexPos], delinearizedIds[vectorPos]});
935 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
938 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
940 auto newRead = rewriter.
create<vector::TransferReadOp>(
941 read.getLoc(), distributedVal.
getType(), read.getSource(), newIndices,
942 read.getPermutationMapAttr(), newPadding, newMask,
943 read.getInBoundsAttr());
954 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
957 newResultTypes.reserve(warpOp->getNumResults());
959 newYieldValues.reserve(warpOp->getNumResults());
962 auto yield = cast<vector::YieldOp>(
963 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
973 for (
OpResult result : warpOp.getResults()) {
974 Value yieldOperand = yield.getOperand(result.getResultNumber());
975 auto it = dedupYieldOperandPositionMap.insert(
976 std::make_pair(yieldOperand, newResultTypes.size()));
977 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
978 if (result.use_empty() || !it.second)
980 newResultTypes.push_back(result.getType());
981 newYieldValues.push_back(yieldOperand);
984 if (yield.getNumOperands() == newYieldValues.size())
988 rewriter, warpOp, newYieldValues, newResultTypes);
991 newWarpOp.getBody()->walk([&](
Operation *op) {
998 newValues.reserve(warpOp->getNumResults());
999 for (
OpResult result : warpOp.getResults()) {
1000 if (result.use_empty())
1001 newValues.push_back(
Value());
1003 newValues.push_back(
1004 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
1013 struct WarpOpForwardOperand :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1015 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1019 auto yield = cast<vector::YieldOp>(
1020 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1022 unsigned resultIndex;
1023 for (
OpOperand &operand : yield->getOpOperands()) {
1032 valForwarded = operand.
get();
1036 auto arg = dyn_cast<BlockArgument>(operand.
get());
1037 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1039 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1042 valForwarded = warpOperand;
1059 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1067 Location loc = broadcastOp.getLoc();
1069 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1070 Value broadcastSrc = broadcastOp.getSource();
1082 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1084 Value broadcasted = rewriter.
create<vector::BroadcastOp>(
1085 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1096 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1106 auto castDistributedType =
1107 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1108 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1109 VectorType castResultType = castDistributedType;
1113 unsigned castDistributedRank = castDistributedType.getRank();
1114 unsigned castOriginalRank = castOriginalType.getRank();
1115 if (castDistributedRank < castOriginalRank) {
1117 llvm::append_range(shape, castDistributedType.getShape());
1118 castDistributedType =
1124 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1127 Value newCast = rewriter.
create<vector::ShapeCastOp>(
1128 oldCastOp.getLoc(), castResultType,
1129 newWarpOp->getResult(newRetIndices[0]));
1155 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1166 if (!llvm::all_of(mask->getOperands(), [&](
Value value) {
1167 return warpOp.isDefinedOutsideOfRegion(value);
1174 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1175 VectorType seqType = mask.getVectorType();
1183 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1184 warpOp.getWarpSize(), warpOp.getLaneid(),
1187 mask,
"cannot delinearize lane ID for distribution");
1188 assert(!delinearizedIds.empty());
1197 for (
int i = 0, e = distShape.size(); i < e; ++i) {
1204 rewriter, loc, s1 - s0 * distShape[i],
1205 {delinearizedIds[i], mask.getOperand(i)});
1206 newOperands.push_back(maskDimIdx);
1210 rewriter.
create<vector::CreateMaskOp>(loc, distType, newOperands);
1221 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1229 VectorType extractSrcType = extractOp.getSourceVectorType();
1233 assert(extractSrcType.getRank() > 0 &&
1234 "vector.extract does not support rank 0 sources");
1238 if (extractOp.getNumIndices() == 0)
1242 if (extractSrcType.getRank() == 1) {
1243 if (extractOp.hasDynamicPosition())
1247 assert(extractOp.getNumIndices() == 1 &&
"expected 1 index");
1248 int64_t pos = extractOp.getStaticPosition()[0];
1251 extractOp, extractOp.getVector(),
1252 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
1258 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1266 rewriter, warpOp, {extractOp.getVector()},
1267 {extractOp.getSourceVectorType()}, newRetIndices);
1269 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1271 Value newExtract = rewriter.
create<vector::ExtractOp>(
1272 loc, distributedVec, extractOp.getMixedPosition());
1279 auto distributedType =
1280 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1281 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1282 int64_t distributedDim = -1;
1283 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1284 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1287 assert(distributedDim == -1 &&
"found multiple distributed dims");
1291 assert(distributedDim != -1 &&
"could not find distributed dimension");
1292 (void)distributedDim;
1296 for (
int i = 0; i < distributedType.getRank(); ++i)
1297 newDistributedShape[i + extractOp.getNumIndices()] =
1298 distributedType.getDimSize(i);
1299 auto newDistributedType =
1300 VectorType::get(newDistributedShape, distributedType.getElementType());
1303 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1306 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1308 Value newExtract = rewriter.
create<vector::ExtractOp>(
1309 loc, distributedVec, extractOp.getMixedPosition());
1318 struct WarpOpExtractElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1319 WarpOpExtractElement(
MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1322 warpShuffleFromIdxFn(std::move(fn)) {}
1323 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1326 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1331 VectorType extractSrcType = extractOp.getSourceVectorType();
1334 if (!extractSrcType.getElementType().isF32() &&
1335 !extractSrcType.getElementType().isInteger(32))
1337 extractOp,
"only f32/i32 element types are supported");
1338 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1339 Type elType = extractSrcType.getElementType();
1340 VectorType distributedVecType;
1341 if (!is0dOrVec1Extract) {
1342 assert(extractSrcType.getRank() == 1 &&
1343 "expected that extractelement src rank is 0 or 1");
1344 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1346 int64_t elementsPerLane =
1347 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1350 distributedVecType = extractSrcType;
1355 if (
static_cast<bool>(extractOp.getPosition())) {
1356 additionalResults.push_back(extractOp.getPosition());
1357 additionalResultTypes.push_back(extractOp.getPosition().getType());
1362 rewriter, warpOp, additionalResults, additionalResultTypes,
1365 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1369 if (is0dOrVec1Extract) {
1371 if (extractSrcType.getRank() == 1) {
1372 newExtract = rewriter.
create<vector::ExtractElementOp>(
1373 loc, distributedVec,
1374 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
1378 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec);
1387 int64_t elementsPerLane = distributedVecType.getShape()[0];
1390 Value broadcastFromTid = rewriter.
create<affine::AffineApplyOp>(
1391 loc, sym0.
ceilDiv(elementsPerLane),
1392 newWarpOp->getResult(newRetIndices[1]));
1395 elementsPerLane == 1
1396 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1398 .create<affine::AffineApplyOp>(
1399 loc, sym0 % elementsPerLane,
1400 newWarpOp->getResult(newRetIndices[1]))
1403 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec, pos);
1406 Value shuffled = warpShuffleFromIdxFn(
1407 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1413 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1416 struct WarpOpInsertElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1419 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1422 getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1427 VectorType vecType = insertOp.getDestVectorType();
1428 VectorType distrType =
1429 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1430 bool hasPos =
static_cast<bool>(insertOp.getPosition());
1434 insertOp.getSource()};
1436 insertOp.getSource().getType()};
1438 additionalResults.push_back(insertOp.getPosition());
1439 additionalResultTypes.push_back(insertOp.getPosition().getType());
1444 rewriter, warpOp, additionalResults, additionalResultTypes,
1447 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1448 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1449 Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) :
Value();
1452 if (vecType == distrType) {
1454 Value newInsert = rewriter.
create<vector::InsertElementOp>(
1455 loc, newSource, distributedVec, newPos);
1462 int64_t elementsPerLane = distrType.getShape()[0];
1465 Value insertingLane = rewriter.
create<affine::AffineApplyOp>(
1466 loc, sym0.
ceilDiv(elementsPerLane), newPos);
1469 elementsPerLane == 1
1470 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1472 .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1475 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1476 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1480 loc, isInsertingLane,
1483 Value newInsert = builder.create<vector::InsertElementOp>(
1484 loc, newSource, distributedVec, pos);
1485 builder.create<scf::YieldOp>(loc, newInsert);
1489 builder.create<scf::YieldOp>(loc, distributedVec);
1500 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1510 if (insertOp.getNumIndices() == 0)
1514 if (insertOp.getDestVectorType().getRank() == 1) {
1515 if (insertOp.hasDynamicPosition())
1519 assert(insertOp.getNumIndices() == 1 &&
"expected 1 index");
1520 int64_t pos = insertOp.getStaticPosition()[0];
1523 insertOp, insertOp.getSource(), insertOp.getDest(),
1524 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
1528 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1533 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1534 {insertOp.getSourceType(), insertOp.getDestVectorType()},
1537 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1538 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1539 Value newResult = rewriter.
create<vector::InsertOp>(
1540 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1547 auto distrDestType =
1548 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1549 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1550 int64_t distrDestDim = -1;
1551 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1552 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1555 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1559 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1562 VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1570 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1571 if (distrSrcDim >= 0)
1572 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1579 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1580 {distrSrcType, distrDestType}, newRetIndices);
1582 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1583 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1587 if (distrSrcDim >= 0) {
1589 newResult = rewriter.
create<vector::InsertOp>(
1590 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1593 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1597 Value insertingLane = rewriter.
create<arith::ConstantIndexOp>(
1598 loc, newPos[distrDestDim] / elementsPerLane);
1599 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1600 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1602 newPos[distrDestDim] %= elementsPerLane;
1604 Value newInsert = builder.
create<vector::InsertOp>(
1605 loc, distributedSrc, distributedDest, newPos);
1606 builder.
create<scf::YieldOp>(loc, newInsert);
1609 builder.
create<scf::YieldOp>(loc, distributedDest);
1611 newResult = rewriter
1612 .
create<scf::IfOp>(loc, isInsertingLane,
1614 nonInsertingBuilder)
1659 distributionMapFn(std::move(fn)) {}
1661 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1663 auto yield = cast<vector::YieldOp>(
1664 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1666 Operation *lastNode = yield->getPrevNode();
1667 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1673 llvm::SmallSetVector<Value, 32> escapingValues;
1677 forOp.getBodyRegion(), [&](
OpOperand *operand) {
1678 Operation *parent = operand->get().getParentRegion()->getParentOp();
1679 if (warpOp->isAncestor(parent)) {
1680 if (!escapingValues.insert(operand->get()))
1682 Type distType = operand->get().getType();
1683 if (auto vecType = dyn_cast<VectorType>(distType)) {
1684 AffineMap map = distributionMapFn(operand->get());
1685 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1687 inputTypes.push_back(operand->get().getType());
1688 distTypes.push_back(distType);
1692 if (llvm::is_contained(distTypes,
Type{}))
1697 rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1699 yield = cast<vector::YieldOp>(
1700 newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1705 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
1708 auto forResult = cast<OpResult>(yieldOperand.
get());
1709 newOperands.push_back(
1711 yieldOperand.
set(forOp.getInitArgs()[forResult.getResultNumber()]);
1720 auto newForOp = rewriter.
create<scf::ForOp>(
1721 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1722 forOp.getStep(), newOperands);
1726 newForOp.getRegionIterArgs().end());
1728 forOp.getResultTypes().end());
1729 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1731 warpInput.push_back(newWarpOp.getResult(retIdx));
1732 argIndexMapping[escapingValues[i]] = warpInputType.size();
1733 warpInputType.push_back(inputTypes[i]);
1735 auto innerWarp = rewriter.
create<WarpExecuteOnLane0Op>(
1736 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1737 newWarpOp.getWarpSize(), warpInput, warpInputType);
1740 argMapping.push_back(newForOp.getInductionVar());
1741 for (
Value args : innerWarp.getBody()->getArguments()) {
1742 argMapping.push_back(args);
1744 argMapping.resize(forOp.getBody()->getNumArguments());
1746 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
1747 yieldOperands.push_back(operand);
1748 rewriter.
eraseOp(forOp.getBody()->getTerminator());
1749 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1751 rewriter.
create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1753 if (!innerWarp.getResults().empty())
1754 rewriter.
create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1759 newForOp.getResult(res.index()));
1760 newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1764 auto it = argIndexMapping.find(operand.
get());
1765 if (it == argIndexMapping.end())
1767 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1772 mlir::vector::moveScalarUniformCode(innerWarp);
1801 DistributedReductionFn distributedReductionFn,
1804 distributedReductionFn(std::move(distributedReductionFn)) {}
1806 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1815 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1817 if (vectorType.getRank() != 1)
1819 warpOp,
"Only rank 1 reductions can be distributed.");
1821 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1823 warpOp,
"Reduction vector dimension must match was size.");
1824 if (!reductionOp.getType().isIntOrFloat())
1826 warpOp,
"Reduction distribution currently only supports floats and "
1829 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1835 if (reductionOp.getAcc()) {
1836 yieldValues.push_back(reductionOp.getAcc());
1837 retTypes.push_back(reductionOp.getAcc().getType());
1841 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1845 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1848 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1849 reductionOp.getKind(), newWarpOp.getWarpSize());
1850 if (reductionOp.getAcc()) {
1852 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1853 newWarpOp.getResult(newRetIndices[1]));
1860 DistributedReductionFn distributedReductionFn;
1871 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1874 patterns.
add<WarpOpTransferWrite>(patterns.
getContext(), distributionMapFn,
1875 maxNumElementsToExtract, benefit);
1878 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1880 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
1882 patterns.
add<WarpOpTransferRead>(patterns.
getContext(), readBenefit);
1884 .
add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1885 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1886 WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1889 warpShuffleFromIdxFn, benefit);
1890 patterns.
add<WarpOpScfForOp>(patterns.
getContext(), distributionMapFn,
1894 void mlir::vector::populateDistributeReduction(
1896 const DistributedReductionFn &distributedReductionFn,
1898 patterns.
add<WarpOpReduction>(patterns.
getContext(), distributedReductionFn,
1902 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1903 Block *body = warpOp.getBody();
1906 llvm::SmallSetVector<Operation *, 8> opsToMove;
1909 auto isDefinedOutsideOfBody = [&](
Value value) {
1911 return (definingOp && opsToMove.count(definingOp)) ||
1912 warpOp.isDefinedOutsideOfRegion(value);
1918 bool hasVectorResult = llvm::any_of(op.
getResults(), [](
Value result) {
1919 return isa<VectorType>(result.getType());
1921 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
1922 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, const std::function< bool(Operation *)> &fn)
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes)
Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, llvm::SmallVector< size_t > &indices)
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
std::function< AffineMap(Value)> DistributionMapFn
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.