18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/FormatVariadic.h"
38 VectorType distributedType) {
44 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
45 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
49 distributedType.getContext());
60 struct DistributedLoadStoreHelper {
61 DistributedLoadStoreHelper(
Value sequentialVal,
Value distributedVal,
63 : sequentialVal(sequentialVal), distributedVal(distributedVal),
64 laneId(laneId), zero(zero) {
65 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.
getType());
66 distributedVectorType = dyn_cast<VectorType>(distributedVal.
getType());
67 if (sequentialVectorType && distributedVectorType)
73 int64_t distributedSize = distributedVectorType.getDimSize(index);
75 return b.
createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
88 assert((val == distributedVal || val == sequentialVal) &&
89 "Must store either the preregistered distributed or the "
90 "preregistered sequential value.");
92 if (!isa<VectorType>(val.
getType()))
93 return b.
create<memref::StoreOp>(loc, val, buffer, zero);
97 int64_t rank = sequentialVectorType.getRank();
99 if (val == distributedVal) {
100 for (
auto dimExpr : distributionMap.getResults()) {
101 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
102 indices[index] = buildDistributedOffset(b, loc, index);
106 return b.
create<vector::TransferWriteOp>(
107 loc, val, buffer, indices,
134 if (!isa<VectorType>(type))
135 return b.
create<memref::LoadOp>(loc, buffer, zero);
140 assert((type == distributedVectorType || type == sequentialVectorType) &&
141 "Must store either the preregistered distributed or the "
142 "preregistered sequential type.");
144 if (type == distributedVectorType) {
145 for (
auto dimExpr : distributionMap.getResults()) {
146 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
147 indices[index] = buildDistributedOffset(b, loc, index);
151 return b.
create<vector::TransferReadOp>(
152 loc, cast<VectorType>(type), buffer, indices,
156 Value sequentialVal, distributedVal, laneId, zero;
157 VectorType sequentialVectorType, distributedVectorType;
170 auto newWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
171 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
172 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
174 Region &opBody = warpOp.getBodyRegion();
175 Region &newOpBody = newWarpOp.getBodyRegion();
179 assert(newWarpOp.getWarpRegion().hasOneBlock() &&
180 "expected WarpOp with single block");
183 cast<vector::YieldOp>(newOpBody.
getBlocks().begin()->getTerminator());
186 yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
197 warpOp.getResultTypes().end());
198 auto yield = cast<vector::YieldOp>(
199 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
200 llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
201 yield.getOperands().end());
202 for (
auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
203 if (yieldValues.insert(std::get<0>(newRet))) {
204 types.push_back(std::get<1>(newRet));
205 indices.push_back(yieldValues.size() - 1);
208 for (
auto [idx, yieldOperand] :
210 if (yieldOperand == std::get<0>(newRet)) {
211 indices.push_back(idx);
217 yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
219 rewriter, warpOp, yieldValues.getArrayRef(), types);
221 newWarpOp.getResults().take_front(warpOp.getNumResults()));
228 return llvm::all_of(op->
getOperands(), definedOutside) &&
235 const std::function<
bool(
Operation *)> &fn) {
236 auto yield = cast<vector::YieldOp>(
237 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
238 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
239 Value yieldValues = yieldOperand.get();
241 if (definedOp && fn(definedOp)) {
242 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
243 return &yieldOperand;
257 return rewriter.
create(res);
290 struct WarpOpToScfIfPattern :
public OpRewritePattern<WarpExecuteOnLane0Op> {
299 assert(warpOp.getBodyRegion().hasOneBlock() &&
300 "expected WarpOp with single block");
301 Block *warpOpBody = &warpOp.getBodyRegion().
front();
309 Value c0 = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
311 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
312 auto ifOp = rewriter.
create<scf::IfOp>(loc, isLane0,
314 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
321 Value distributedVal = it.value();
322 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
323 warpOp.getLaneid(), c0);
327 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
330 helper.buildStore(rewriter, loc, distributedVal, buffer);
333 bbArgReplacements.push_back(
334 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
338 if (!warpOp.getArgs().empty()) {
340 options.warpSyncronizationFn(loc, rewriter, warpOp);
344 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
351 auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
352 Location yieldLoc = yieldOp.getLoc();
354 Value sequentialVal = it.value();
355 Value distributedVal = warpOp->getResult(it.index());
356 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
357 warpOp.getLaneid(), c0);
361 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
367 helper.buildStore(rewriter, loc, sequentialVal, buffer);
378 replacements.push_back(
379 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
383 if (!yieldOp.getOperands().empty()) {
385 options.warpSyncronizationFn(loc, rewriter, warpOp);
391 rewriter.
create<scf::YieldOp>(yieldLoc);
394 rewriter.
replaceOp(warpOp, replacements);
407 static vector::TransferWriteOp cloneWriteOp(
RewriterBase &rewriter,
408 WarpExecuteOnLane0Op warpOp,
409 vector::TransferWriteOp writeOp,
410 VectorType targetType,
411 VectorType maybeMaskType) {
412 assert(writeOp->getParentOp() == warpOp &&
413 "write must be nested immediately under warp");
416 WarpExecuteOnLane0Op newWarpOp;
419 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
420 TypeRange{targetType, maybeMaskType}, newRetIndices);
423 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
428 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
430 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
432 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
444 static VectorType getDistributedType(VectorType originalType,
AffineMap map,
447 originalType.getShape().end());
450 if (targetShape[position] % warpSize != 0) {
451 if (warpSize % targetShape[position] != 0) {
454 warpSize /= targetShape[position];
455 targetShape[position] = 1;
458 targetShape[position] = targetShape[position] / warpSize;
465 VectorType targetType =
489 struct WarpOpTransferWrite :
public OpRewritePattern<WarpExecuteOnLane0Op> {
493 distributionMapFn(std::move(fn)),
494 maxNumElementsToExtract(maxNumElementsToExtract) {}
499 vector::TransferWriteOp writeOp,
500 WarpExecuteOnLane0Op warpOp)
const {
501 VectorType writtenVectorType = writeOp.getVectorType();
505 if (writtenVectorType.getRank() == 0)
509 AffineMap map = distributionMapFn(writeOp.getVector());
510 VectorType targetType =
511 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
517 if (writeOp.getMask()) {
524 if (!writeOp.getPermutationMap().isMinorIdentity())
527 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
532 vector::TransferWriteOp newWriteOp =
533 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
537 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
544 for (
auto [seqSize, distSize] :
545 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
546 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
547 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
551 delinearized = rewriter
552 .
create<mlir::affine::AffineDelinearizeIndexOp>(
553 newWarpOp.getLoc(), newWarpOp.getLaneid(),
559 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
565 newWriteOp.getIndices().end());
568 bindDims(newWarpOp.getContext(), d0, d1);
569 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
572 unsigned indexPos = indexExpr.getPosition();
573 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
574 Value laneId = delinearized[vectorPos];
578 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
580 newWriteOp.getIndicesMutable().assign(indices);
587 vector::TransferWriteOp writeOp,
588 WarpExecuteOnLane0Op warpOp)
const {
590 VectorType vecType = writeOp.getVectorType();
592 if (vecType.getNumElements() > maxNumElementsToExtract) {
596 "writes more elements ({0}) than allowed to extract ({1})",
597 vecType.getNumElements(), maxNumElementsToExtract));
601 if (llvm::all_of(warpOp.getOps(),
602 llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
609 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
613 auto secondWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
614 loc,
TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
615 Block &body = secondWarpOp.getBodyRegion().
front();
618 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
619 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
621 rewriter.
create<vector::YieldOp>(newWarpOp.getLoc());
627 auto yield = cast<vector::YieldOp>(
628 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
629 Operation *lastNode = yield->getPrevNode();
630 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
634 Value maybeMask = writeOp.getMask();
635 if (!llvm::all_of(writeOp->getOperands(), [&](
Value value) {
636 return writeOp.getVector() == value ||
637 (maybeMask && maybeMask == value) ||
638 warpOp.isDefinedOutsideOfRegion(value);
642 if (
succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
646 if (writeOp.getMask())
649 if (
succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
657 unsigned maxNumElementsToExtract = 1;
690 Value distributedVal = warpOp.getResult(operandIndex);
696 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
698 auto operandType = cast<VectorType>(operand.get().getType());
702 auto operandType = operand.get().getType();
703 assert(!isa<VectorType>(operandType) &&
704 "unexpected yield of vector from op with scalar result type");
705 targetType = operandType;
707 retTypes.push_back(targetType);
708 yieldValues.push_back(operand.get());
712 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
716 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
717 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
722 rewriter, loc, elementWise, newOperands,
723 {newWarpOp.getResult(operandIndex).getType()});
753 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
762 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
765 Value distConstant = rewriter.
create<arith::ConstantOp>(loc, newAttr);
784 if (originalShape == distributedShape) {
785 delinearizedIds.clear();
790 for (
auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
791 if (large % small != 0)
793 sizes.push_back(large / small);
795 if (std::accumulate(sizes.begin(), sizes.end(), 1,
796 std::multiplies<int64_t>()) != warpSize)
802 int64_t usedThreads = 1;
804 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
805 delinearizedIds.assign(sizes.size(), zero);
807 for (
int i = sizes.size() - 1; i >= 0; --i) {
808 usedThreads *= sizes[i];
809 if (usedThreads == warpSize) {
812 delinearizedIds[i] = laneId;
818 builder, loc, s0.
floorDiv(usedThreads), {laneId});
850 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
854 warpOp,
"warp result is not a vector.transfer_read op");
858 if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
860 read,
"source must be defined outside of the region");
863 Value distributedVal = warpOp.getResult(operandIndex);
866 read.getIndices().end());
867 auto sequentialType = cast<VectorType>(read.getResult().getType());
868 auto distributedType = cast<VectorType>(distributedVal.
getType());
875 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
876 distributedType.getShape(), warpOp.getWarpSize(),
877 warpOp.getLaneid(), delinearizedIds)) {
879 read,
"cannot delinearize lane ID for distribution");
881 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
888 additionalResults.push_back(read.getPadding());
889 additionalResultTypes.push_back(read.getPadding().getType());
891 bool hasMask =
false;
892 if (read.getMask()) {
902 read,
"non-trivial permutation maps not supported");
903 VectorType maskType =
904 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
905 additionalResults.push_back(read.getMask());
906 additionalResultTypes.push_back(maskType);
911 rewriter, warpOp, additionalResults, additionalResultTypes,
913 distributedVal = newWarpOp.getResult(operandIndex);
917 for (int64_t i = 0, e = indices.size(); i < e; ++i)
918 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
923 bindDims(read.getContext(), d0, d1);
924 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
927 unsigned indexPos = indexExpr.getPosition();
928 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
929 int64_t scale = distributedType.getDimSize(vectorPos);
931 rewriter, read.getLoc(), d0 + scale * d1,
932 {newIndices[indexPos], delinearizedIds[vectorPos]});
936 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
939 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
941 auto newRead = rewriter.
create<vector::TransferReadOp>(
942 read.getLoc(), distributedVal.
getType(), read.getSource(), newIndices,
943 read.getPermutationMapAttr(), newPadding, newMask,
944 read.getInBoundsAttr());
958 newResultTypes.reserve(warpOp->getNumResults());
960 newYieldValues.reserve(warpOp->getNumResults());
963 auto yield = cast<vector::YieldOp>(
964 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
974 for (
OpResult result : warpOp.getResults()) {
975 Value yieldOperand = yield.getOperand(result.getResultNumber());
976 auto it = dedupYieldOperandPositionMap.insert(
977 std::make_pair(yieldOperand, newResultTypes.size()));
978 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
979 if (result.use_empty() || !it.second)
981 newResultTypes.push_back(result.getType());
982 newYieldValues.push_back(yieldOperand);
985 if (yield.getNumOperands() == newYieldValues.size())
989 rewriter, warpOp, newYieldValues, newResultTypes);
992 newWarpOp.getBody()->walk([&](
Operation *op) {
999 newValues.reserve(warpOp->getNumResults());
1000 for (
OpResult result : warpOp.getResults()) {
1001 if (result.use_empty())
1002 newValues.push_back(
Value());
1004 newValues.push_back(
1005 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
1014 struct WarpOpForwardOperand :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1020 auto yield = cast<vector::YieldOp>(
1021 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1023 unsigned resultIndex;
1024 for (
OpOperand &operand : yield->getOpOperands()) {
1033 valForwarded = operand.
get();
1037 auto arg = dyn_cast<BlockArgument>(operand.
get());
1038 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1040 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1043 valForwarded = warpOperand;
1068 Location loc = broadcastOp.getLoc();
1070 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1071 Value broadcastSrc = broadcastOp.getSource();
1083 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1085 Value broadcasted = rewriter.
create<vector::BroadcastOp>(
1086 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1107 auto castDistributedType =
1108 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1109 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1110 VectorType castResultType = castDistributedType;
1114 unsigned castDistributedRank = castDistributedType.getRank();
1115 unsigned castOriginalRank = castOriginalType.getRank();
1116 if (castDistributedRank < castOriginalRank) {
1118 llvm::append_range(shape, castDistributedType.getShape());
1119 castDistributedType =
1125 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1128 Value newCast = rewriter.
create<vector::ShapeCastOp>(
1129 oldCastOp.getLoc(), castResultType,
1130 newWarpOp->getResult(newRetIndices[0]));
1167 if (!llvm::all_of(mask->getOperands(), [&](
Value value) {
1168 return warpOp.isDefinedOutsideOfRegion(value);
1175 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1176 VectorType seqType = mask.getVectorType();
1184 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1185 warpOp.getWarpSize(), warpOp.getLaneid(),
1188 mask,
"cannot delinearize lane ID for distribution");
1189 assert(!delinearizedIds.empty());
1198 for (
int i = 0, e = distShape.size(); i < e; ++i) {
1205 rewriter, loc, s1 - s0 * distShape[i],
1206 {delinearizedIds[i], mask.getOperand(i)});
1207 newOperands.push_back(maskDimIdx);
1211 rewriter.
create<vector::CreateMaskOp>(loc, distType, newOperands);
1230 VectorType extractSrcType = extractOp.getSourceVectorType();
1234 assert(extractSrcType.getRank() > 0 &&
1235 "vector.extract does not support rank 0 sources");
1239 if (extractOp.getNumIndices() == 0)
1243 if (extractSrcType.getRank() == 1) {
1244 if (extractOp.hasDynamicPosition())
1248 assert(extractOp.getNumIndices() == 1 &&
"expected 1 index");
1249 int64_t pos = extractOp.getStaticPosition()[0];
1252 extractOp, extractOp.getVector(),
1253 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
1259 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1267 rewriter, warpOp, {extractOp.getVector()},
1268 {extractOp.getSourceVectorType()}, newRetIndices);
1270 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1272 Value newExtract = rewriter.
create<vector::ExtractOp>(
1273 loc, distributedVec, extractOp.getMixedPosition());
1280 auto distributedType =
1281 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1282 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1283 int64_t distributedDim = -1;
1284 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1285 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1288 assert(distributedDim == -1 &&
"found multiple distributed dims");
1292 assert(distributedDim != -1 &&
"could not find distributed dimension");
1293 (void)distributedDim;
1297 extractSrcType.getShape().end());
1298 for (
int i = 0; i < distributedType.getRank(); ++i)
1299 newDistributedShape[i + extractOp.getNumIndices()] =
1300 distributedType.getDimSize(i);
1301 auto newDistributedType =
1302 VectorType::get(newDistributedShape, distributedType.getElementType());
1305 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1308 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1310 Value newExtract = rewriter.
create<vector::ExtractOp>(
1311 loc, distributedVec, extractOp.getMixedPosition());
1320 struct WarpOpExtractElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1321 WarpOpExtractElement(
MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1324 warpShuffleFromIdxFn(std::move(fn)) {}
1328 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1333 VectorType extractSrcType = extractOp.getSourceVectorType();
1336 if (!extractSrcType.getElementType().isF32() &&
1337 !extractSrcType.getElementType().isInteger(32))
1339 extractOp,
"only f32/i32 element types are supported");
1340 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1341 Type elType = extractSrcType.getElementType();
1342 VectorType distributedVecType;
1343 if (!is0dOrVec1Extract) {
1344 assert(extractSrcType.getRank() == 1 &&
1345 "expected that extractelement src rank is 0 or 1");
1346 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1348 int64_t elementsPerLane =
1349 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1352 distributedVecType = extractSrcType;
1357 if (
static_cast<bool>(extractOp.getPosition())) {
1358 additionalResults.push_back(extractOp.getPosition());
1359 additionalResultTypes.push_back(extractOp.getPosition().getType());
1364 rewriter, warpOp, additionalResults, additionalResultTypes,
1367 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1371 if (is0dOrVec1Extract) {
1373 if (extractSrcType.getRank() == 1) {
1374 newExtract = rewriter.
create<vector::ExtractElementOp>(
1375 loc, distributedVec,
1376 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
1380 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec);
1389 int64_t elementsPerLane = distributedVecType.getShape()[0];
1392 Value broadcastFromTid = rewriter.
create<affine::AffineApplyOp>(
1393 loc, sym0.
ceilDiv(elementsPerLane),
1394 newWarpOp->getResult(newRetIndices[1]));
1397 elementsPerLane == 1
1398 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1400 .create<affine::AffineApplyOp>(
1401 loc, sym0 % elementsPerLane,
1402 newWarpOp->getResult(newRetIndices[1]))
1405 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec, pos);
1408 Value shuffled = warpShuffleFromIdxFn(
1409 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1415 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1418 struct WarpOpInsertElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1424 getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1429 VectorType vecType = insertOp.getDestVectorType();
1430 VectorType distrType =
1431 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1432 bool hasPos =
static_cast<bool>(insertOp.getPosition());
1436 insertOp.getSource()};
1438 insertOp.getSource().getType()};
1440 additionalResults.push_back(insertOp.getPosition());
1441 additionalResultTypes.push_back(insertOp.getPosition().getType());
1446 rewriter, warpOp, additionalResults, additionalResultTypes,
1449 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1450 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1451 Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) :
Value();
1454 if (vecType == distrType) {
1456 Value newInsert = rewriter.
create<vector::InsertElementOp>(
1457 loc, newSource, distributedVec, newPos);
1464 int64_t elementsPerLane = distrType.getShape()[0];
1467 Value insertingLane = rewriter.
create<affine::AffineApplyOp>(
1468 loc, sym0.
ceilDiv(elementsPerLane), newPos);
1471 elementsPerLane == 1
1472 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1474 .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1477 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1478 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1482 loc, isInsertingLane,
1485 Value newInsert = builder.create<vector::InsertElementOp>(
1486 loc, newSource, distributedVec, pos);
1487 builder.create<scf::YieldOp>(loc, newInsert);
1491 builder.create<scf::YieldOp>(loc, distributedVec);
1512 if (insertOp.getNumIndices() == 0)
1516 if (insertOp.getDestVectorType().getRank() == 1) {
1517 if (insertOp.hasDynamicPosition())
1521 assert(insertOp.getNumIndices() == 1 &&
"expected 1 index");
1522 int64_t pos = insertOp.getStaticPosition()[0];
1525 insertOp, insertOp.getSource(), insertOp.getDest(),
1526 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
1530 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1535 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1536 {insertOp.getSourceType(), insertOp.getDestVectorType()},
1539 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1540 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1541 Value newResult = rewriter.
create<vector::InsertOp>(
1542 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1549 auto distrDestType =
1550 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1551 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1552 int64_t distrDestDim = -1;
1553 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1554 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1557 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1561 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1564 VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1566 srcVecType.getShape().end());
1573 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1574 if (distrSrcDim >= 0)
1575 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1582 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1583 {distrSrcType, distrDestType}, newRetIndices);
1585 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1586 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1590 if (distrSrcDim >= 0) {
1592 newResult = rewriter.
create<vector::InsertOp>(
1593 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1596 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1600 Value insertingLane = rewriter.
create<arith::ConstantIndexOp>(
1601 loc, newPos[distrDestDim] / elementsPerLane);
1602 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1603 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1605 newPos[distrDestDim] %= elementsPerLane;
1607 Value newInsert = builder.
create<vector::InsertOp>(
1608 loc, distributedSrc, distributedDest, newPos);
1609 builder.
create<scf::YieldOp>(loc, newInsert);
1612 builder.
create<scf::YieldOp>(loc, distributedDest);
1614 newResult = rewriter
1615 .
create<scf::IfOp>(loc, isInsertingLane,
1617 nonInsertingBuilder)
1662 distributionMapFn(std::move(fn)) {}
1666 auto yield = cast<vector::YieldOp>(
1667 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1669 Operation *lastNode = yield->getPrevNode();
1670 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1676 llvm::SmallSetVector<Value, 32> escapingValues;
1680 forOp.getBodyRegion(), [&](
OpOperand *operand) {
1681 Operation *parent = operand->get().getParentRegion()->getParentOp();
1682 if (warpOp->isAncestor(parent)) {
1683 if (!escapingValues.insert(operand->get()))
1685 Type distType = operand->get().getType();
1686 if (auto vecType = dyn_cast<VectorType>(distType)) {
1687 AffineMap map = distributionMapFn(operand->get());
1688 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1690 inputTypes.push_back(operand->get().getType());
1691 distTypes.push_back(distType);
1697 rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1699 yield = cast<vector::YieldOp>(
1700 newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1705 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
1708 auto forResult = cast<OpResult>(yieldOperand.
get());
1709 newOperands.push_back(
1711 yieldOperand.
set(forOp.getInitArgs()[forResult.getResultNumber()]);
1720 auto newForOp = rewriter.
create<scf::ForOp>(
1721 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1722 forOp.getStep(), newOperands);
1726 newForOp.getRegionIterArgs().end());
1728 forOp.getResultTypes().end());
1729 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1731 warpInput.push_back(newWarpOp.getResult(retIdx));
1732 argIndexMapping[escapingValues[i]] = warpInputType.size();
1733 warpInputType.push_back(inputTypes[i]);
1735 auto innerWarp = rewriter.
create<WarpExecuteOnLane0Op>(
1736 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1737 newWarpOp.getWarpSize(), warpInput, warpInputType);
1740 argMapping.push_back(newForOp.getInductionVar());
1741 for (
Value args : innerWarp.getBody()->getArguments()) {
1742 argMapping.push_back(args);
1744 argMapping.resize(forOp.getBody()->getNumArguments());
1746 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
1747 yieldOperands.push_back(operand);
1748 rewriter.
eraseOp(forOp.getBody()->getTerminator());
1749 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1751 rewriter.
create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1753 if (!innerWarp.getResults().empty())
1754 rewriter.
create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1759 newForOp.getResult(res.index()));
1760 newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1764 auto it = argIndexMapping.find(operand.
get());
1765 if (it == argIndexMapping.end())
1767 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1772 mlir::vector::moveScalarUniformCode(innerWarp);
1801 DistributedReductionFn distributedReductionFn,
1804 distributedReductionFn(std::move(distributedReductionFn)) {}
1815 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1817 if (vectorType.getRank() != 1)
1819 warpOp,
"Only rank 1 reductions can be distributed.");
1821 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1823 warpOp,
"Reduction vector dimension must match was size.");
1824 if (!reductionOp.getType().isIntOrFloat())
1826 warpOp,
"Reduction distribution currently only supports floats and "
1829 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1835 if (reductionOp.getAcc()) {
1836 yieldValues.push_back(reductionOp.getAcc());
1837 retTypes.push_back(reductionOp.getAcc().getType());
1841 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1845 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1848 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1849 reductionOp.getKind(), newWarpOp.getWarpSize());
1850 if (reductionOp.getAcc()) {
1852 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1853 newWarpOp.getResult(newRetIndices[1]));
1860 DistributedReductionFn distributedReductionFn;
1871 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1874 patterns.
add<WarpOpTransferWrite>(patterns.
getContext(), distributionMapFn,
1875 maxNumElementsToExtract, benefit);
1878 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1880 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
1882 patterns.
add<WarpOpTransferRead>(patterns.
getContext(), readBenefit);
1884 .
add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1885 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1886 WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1889 warpShuffleFromIdxFn, benefit);
1890 patterns.
add<WarpOpScfForOp>(patterns.
getContext(), distributionMapFn,
1894 void mlir::vector::populateDistributeReduction(
1896 const DistributedReductionFn &distributedReductionFn,
1898 patterns.
add<WarpOpReduction>(patterns.
getContext(), distributedReductionFn,
1902 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1903 Block *body = warpOp.getBody();
1906 llvm::SmallSetVector<Operation *, 8> opsToMove;
1909 auto isDefinedOutsideOfBody = [&](
Value value) {
1911 return (definingOp && opsToMove.count(definingOp)) ||
1912 warpOp.isDefinedOutsideOfRegion(value);
1918 bool hasVectorResult = llvm::any_of(op.
getResults(), [](
Value result) {
1919 return isa<VectorType>(result.getType());
1921 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
1922 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, const std::function< bool(Operation *)> &fn)
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes)
Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, llvm::SmallVector< size_t > &indices)
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
std::function< AffineMap(Value)> DistributionMapFn
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.