18 #include "llvm/ADT/SetVector.h"
37 VectorType distributedType) {
43 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
44 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
48 distributedType.getContext());
59 struct DistributedLoadStoreHelper {
60 DistributedLoadStoreHelper(
Value sequentialVal,
Value distributedVal,
62 : sequentialVal(sequentialVal), distributedVal(distributedVal),
63 laneId(laneId), zero(zero) {
64 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.
getType());
65 distributedVectorType = dyn_cast<VectorType>(distributedVal.
getType());
66 if (sequentialVectorType && distributedVectorType)
72 int64_t distributedSize = distributedVectorType.getDimSize(index);
74 return b.
createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
87 assert((val == distributedVal || val == sequentialVal) &&
88 "Must store either the preregistered distributed or the "
89 "preregistered sequential value.");
91 if (!isa<VectorType>(val.
getType()))
92 return b.
create<memref::StoreOp>(loc, val, buffer, zero);
96 int64_t rank = sequentialVectorType.getRank();
98 if (val == distributedVal) {
99 for (
auto dimExpr : distributionMap.getResults()) {
101 indices[index] = buildDistributedOffset(b, loc, index);
105 return b.
create<vector::TransferWriteOp>(
106 loc, val, buffer, indices,
133 if (!isa<VectorType>(type))
134 return b.
create<memref::LoadOp>(loc, buffer, zero);
139 assert((type == distributedVectorType || type == sequentialVectorType) &&
140 "Must store either the preregistered distributed or the "
141 "preregistered sequential type.");
143 if (type == distributedVectorType) {
144 for (
auto dimExpr : distributionMap.getResults()) {
146 indices[index] = buildDistributedOffset(b, loc, index);
150 return b.
create<vector::TransferReadOp>(
151 loc, cast<VectorType>(type), buffer, indices,
155 Value sequentialVal, distributedVal, laneId, zero;
156 VectorType sequentialVectorType, distributedVectorType;
169 auto newWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
170 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
171 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
173 Region &opBody = warpOp.getBodyRegion();
174 Region &newOpBody = newWarpOp.getBodyRegion();
178 assert(newWarpOp.getWarpRegion().hasOneBlock() &&
179 "expected WarpOp with single block");
182 cast<vector::YieldOp>(newOpBody.
getBlocks().begin()->getTerminator());
185 yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
196 warpOp.getResultTypes().end());
197 auto yield = cast<vector::YieldOp>(
198 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
199 llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
200 yield.getOperands().end());
201 for (
auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
202 if (yieldValues.insert(std::get<0>(newRet))) {
203 types.push_back(std::get<1>(newRet));
204 indices.push_back(yieldValues.size() - 1);
207 for (
auto [idx, yieldOperand] :
209 if (yieldOperand == std::get<0>(newRet)) {
210 indices.push_back(idx);
216 yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
218 rewriter, warpOp, yieldValues.getArrayRef(), types);
220 newWarpOp.getResults().take_front(warpOp.getNumResults()));
227 return llvm::all_of(op->
getOperands(), definedOutside) &&
234 const std::function<
bool(
Operation *)> &fn) {
235 auto yield = cast<vector::YieldOp>(
236 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
237 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
238 Value yieldValues = yieldOperand.get();
240 if (definedOp && fn(definedOp)) {
241 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
242 return &yieldOperand;
256 return rewriter.
create(res);
289 struct WarpOpToScfIfPattern :
public OpRewritePattern<WarpExecuteOnLane0Op> {
298 assert(warpOp.getBodyRegion().hasOneBlock() &&
299 "expected WarpOp with single block");
300 Block *warpOpBody = &warpOp.getBodyRegion().
front();
308 Value c0 = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
310 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
311 auto ifOp = rewriter.
create<scf::IfOp>(loc, isLane0,
313 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
320 Value distributedVal = it.value();
321 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
322 warpOp.getLaneid(), c0);
326 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
329 helper.buildStore(rewriter, loc, distributedVal, buffer);
332 bbArgReplacements.push_back(
333 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
337 if (!warpOp.getArgs().empty()) {
339 options.warpSyncronizationFn(loc, rewriter, warpOp);
343 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
350 auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
351 Location yieldLoc = yieldOp.getLoc();
353 Value sequentialVal = it.value();
354 Value distributedVal = warpOp->getResult(it.index());
355 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
356 warpOp.getLaneid(), c0);
360 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
366 helper.buildStore(rewriter, loc, sequentialVal, buffer);
377 replacements.push_back(
378 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
382 if (!yieldOp.getOperands().empty()) {
384 options.warpSyncronizationFn(loc, rewriter, warpOp);
390 rewriter.
create<scf::YieldOp>(yieldLoc);
393 rewriter.
replaceOp(warpOp, replacements);
406 static vector::TransferWriteOp cloneWriteOp(
RewriterBase &rewriter,
407 WarpExecuteOnLane0Op warpOp,
408 vector::TransferWriteOp writeOp,
409 VectorType targetType) {
410 assert(writeOp->getParentOp() == warpOp &&
411 "write must be nested immediately under warp");
415 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
419 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
421 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
433 static VectorType getDistributedType(VectorType originalType,
AffineMap map,
438 originalType.getShape().end());
441 if (targetShape[position] % warpSize != 0)
443 targetShape[position] = targetShape[position] / warpSize;
445 VectorType targetType =
467 struct WarpOpTransferWrite :
public OpRewritePattern<vector::TransferWriteOp> {
471 distributionMapFn(std::move(fn)) {}
476 vector::TransferWriteOp writeOp,
477 WarpExecuteOnLane0Op warpOp)
const {
478 VectorType writtenVectorType = writeOp.getVectorType();
482 if (writtenVectorType.getRank() == 0)
486 AffineMap map = distributionMapFn(writeOp.getVector());
487 VectorType targetType =
488 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
494 vector::TransferWriteOp newWriteOp =
495 cloneWriteOp(rewriter, warpOp, writeOp, targetType);
499 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
504 newWriteOp.getIndices().end());
507 bindDims(newWarpOp.getContext(), d0, d1);
512 unsigned vectorPos = std::get<1>(it).cast<
AffineDimExpr>().getPosition();
516 rewriter, loc, d0 + scale * d1,
517 {indices[indexPos], newWarpOp.getLaneid()});
519 newWriteOp.getIndicesMutable().assign(indices);
526 vector::TransferWriteOp writeOp,
527 WarpExecuteOnLane0Op warpOp)
const {
529 VectorType vecType = writeOp.getVectorType();
533 if (vecType.getNumElements() != 1)
537 if (llvm::all_of(warpOp.getOps(), [](
Operation &op) {
538 return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
546 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
550 auto secondWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
551 loc,
TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
552 Block &body = secondWarpOp.getBodyRegion().
front();
555 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
556 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
558 rewriter.
create<vector::YieldOp>(newWarpOp.getLoc());
562 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
565 if (writeOp.getMask())
568 auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
573 Operation *nextOp = writeOp.getOperation();
574 while ((nextOp = nextOp->getNextNode()))
578 if (!llvm::all_of(writeOp->getOperands(), [&](
Value value) {
579 return writeOp.getVector() == value ||
580 warpOp.isDefinedOutsideOfRegion(value);
584 if (
succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
587 if (
succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
626 Value distributedVal = warpOp.getResult(operandIndex);
632 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
634 auto operandType = cast<VectorType>(operand.get().getType());
638 auto operandType = operand.get().getType();
639 assert(!isa<VectorType>(operandType) &&
640 "unexpected yield of vector from op with scalar result type");
641 targetType = operandType;
643 retTypes.push_back(targetType);
644 yieldValues.push_back(operand.get());
648 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
652 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
653 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
658 rewriter, loc, elementWise, newOperands,
659 {newWarpOp.getResult(operandIndex).getType()});
685 warpOp, [](
Operation *op) {
return isa<arith::ConstantOp>(op); });
689 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
695 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
698 Value distConstant = rewriter.
create<arith::ConstantOp>(loc, newAttr);
716 if (originalShape == distributedShape) {
717 delinearizedIds.clear();
722 for (
auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
723 if (large % small != 0)
725 sizes.push_back(large / small);
727 if (std::accumulate(sizes.begin(), sizes.end(), 1,
728 std::multiplies<int64_t>()) != warpSize)
734 int64_t usedThreads = 1;
736 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
737 delinearizedIds.assign(sizes.size(), zero);
739 for (
int i = sizes.size() - 1; i >= 0; --i) {
740 usedThreads *= sizes[i];
741 if (usedThreads == warpSize) {
744 delinearizedIds[i] = laneId;
750 builder, loc, s0.
floorDiv(usedThreads), {laneId});
778 warpOp, [](
Operation *op) {
return isa<vector::TransferReadOp>(op); });
783 if (!read.getResult().hasOneUse())
786 Value distributedVal = warpOp.getResult(operandIndex);
789 read.getIndices().end());
790 auto sequentialType = cast<VectorType>(read.getResult().getType());
791 auto distributedType = cast<VectorType>(distributedVal.
getType());
800 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
801 distributedType.getShape(), warpOp.getWarpSize(),
802 warpOp.getLaneid(), delinearizedIds))
804 read,
"cannot delinearize lane ID for distribution");
805 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
809 bindDims(read.getContext(), d0, d1);
814 unsigned vectorPos = std::get<1>(it).cast<
AffineDimExpr>().getPosition();
815 int64_t scale = distributedType.getDimSize(vectorPos);
817 rewriter, read.getLoc(), d0 + scale * d1,
818 {indices[indexPos], delinearizedIds[vectorPos]});
820 auto newRead = rewriter.
create<vector::TransferReadOp>(
821 read.getLoc(), distributedVal.
getType(), read.getSource(), indices,
822 read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
823 read.getInBoundsAttr());
840 if (!llvm::all_of(newRead->getOperands(), [&](
Value value) {
841 return warpOp.isDefinedOutsideOfRegion(value);
857 newResultTypes.reserve(warpOp->getNumResults());
859 newYieldValues.reserve(warpOp->getNumResults());
862 auto yield = cast<vector::YieldOp>(
863 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
873 for (
OpResult result : warpOp.getResults()) {
874 Value yieldOperand = yield.getOperand(result.getResultNumber());
875 auto it = dedupYieldOperandPositionMap.insert(
876 std::make_pair(yieldOperand, newResultTypes.size()));
877 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
878 if (result.use_empty() || !it.second)
880 newResultTypes.push_back(result.getType());
881 newYieldValues.push_back(yieldOperand);
884 if (yield.getNumOperands() == newYieldValues.size())
888 rewriter, warpOp, newYieldValues, newResultTypes);
891 newValues.reserve(warpOp->getNumResults());
892 for (
OpResult result : warpOp.getResults()) {
893 if (result.use_empty())
894 newValues.push_back(
Value());
897 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
906 struct WarpOpForwardOperand :
public OpRewritePattern<WarpExecuteOnLane0Op> {
912 auto yield = cast<vector::YieldOp>(
913 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
915 unsigned resultIndex;
916 for (
OpOperand &operand : yield->getOpOperands()) {
925 valForwarded = operand.
get();
929 auto arg = dyn_cast<BlockArgument>(operand.
get());
930 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
932 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
935 valForwarded = warpOperand;
951 warpOp, [](
Operation *op) {
return isa<vector::BroadcastOp>(op); });
956 Location loc = broadcastOp.getLoc();
958 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
959 Value broadcastSrc = broadcastOp.getSource();
971 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
973 Value broadcasted = rewriter.
create<vector::BroadcastOp>(
974 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
988 warpOp, [](
Operation *op) {
return isa<vector::ShapeCastOp>(op); });
994 auto castDistributedType =
995 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
996 VectorType castOriginalType = oldCastOp.getSourceVectorType();
997 VectorType castResultType = castDistributedType;
1001 unsigned castDistributedRank = castDistributedType.getRank();
1002 unsigned castOriginalRank = castOriginalType.getRank();
1003 if (castDistributedRank < castOriginalRank) {
1005 llvm::append_range(shape, castDistributedType.getShape());
1006 castDistributedType =
1012 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1015 Value newCast = rewriter.
create<vector::ShapeCastOp>(
1016 oldCastOp.getLoc(), castResultType,
1017 newWarpOp->getResult(newRetIndices[0]));
1030 warpOp, [](
Operation *op) {
return isa<vector::ExtractOp>(op); });
1035 VectorType extractSrcType = extractOp.getSourceVectorType();
1039 assert(extractSrcType.getRank() > 0 &&
1040 "vector.extract does not support rank 0 sources");
1044 if (extractOp.getNumIndices() == 0)
1048 if (extractSrcType.getRank() == 1) {
1049 if (extractOp.hasDynamicPosition())
1053 assert(extractOp.getNumIndices() == 1 &&
"expected 1 index");
1054 int64_t pos = extractOp.getStaticPosition()[0];
1057 extractOp, extractOp.getVector(),
1058 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
1064 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1072 rewriter, warpOp, {extractOp.getVector()},
1073 {extractOp.getSourceVectorType()}, newRetIndices);
1075 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1077 Value newExtract = rewriter.
create<vector::ExtractOp>(
1078 loc, distributedVec, extractOp.getMixedPosition());
1085 auto distributedType =
1086 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1087 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1088 int64_t distributedDim = -1;
1089 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1090 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1093 assert(distributedDim == -1 &&
"found multiple distributed dims");
1097 assert(distributedDim != -1 &&
"could not find distributed dimension");
1098 (void)distributedDim;
1102 extractSrcType.getShape().end());
1103 for (
int i = 0; i < distributedType.getRank(); ++i)
1104 newDistributedShape[i + extractOp.getNumIndices()] =
1105 distributedType.getDimSize(i);
1106 auto newDistributedType =
1107 VectorType::get(newDistributedShape, distributedType.getElementType());
1110 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1113 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1115 Value newExtract = rewriter.
create<vector::ExtractOp>(
1116 loc, distributedVec, extractOp.getMixedPosition());
1125 struct WarpOpExtractElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1126 WarpOpExtractElement(
MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1129 warpShuffleFromIdxFn(std::move(fn)) {}
1133 return isa<vector::ExtractElementOp>(op);
1139 VectorType extractSrcType = extractOp.getSourceVectorType();
1140 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1141 Type elType = extractSrcType.getElementType();
1142 VectorType distributedVecType;
1143 if (!is0dOrVec1Extract) {
1144 assert(extractSrcType.getRank() == 1 &&
1145 "expected that extractelement src rank is 0 or 1");
1146 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1148 int64_t elementsPerLane =
1149 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1152 distributedVecType = extractSrcType;
1158 rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
1161 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1165 if (is0dOrVec1Extract) {
1167 if (extractSrcType.getRank() == 1) {
1168 newExtract = rewriter.
create<vector::ExtractElementOp>(
1169 loc, distributedVec,
1170 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
1174 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec);
1183 int64_t elementsPerLane = distributedVecType.getShape()[0];
1186 Value broadcastFromTid = rewriter.
create<affine::AffineApplyOp>(
1187 loc, sym0.
ceilDiv(elementsPerLane), extractOp.getPosition());
1190 elementsPerLane == 1
1191 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1193 .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1194 extractOp.getPosition())
1197 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec, pos);
1200 Value shuffled = warpShuffleFromIdxFn(
1201 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1207 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1210 struct WarpOpInsertElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1216 warpOp, [](
Operation *op) {
return isa<vector::InsertElementOp>(op); });
1221 VectorType vecType = insertOp.getDestVectorType();
1222 VectorType distrType =
1223 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1224 bool hasPos =
static_cast<bool>(insertOp.getPosition());
1228 insertOp.getSource()};
1230 insertOp.getSource().getType()};
1232 additionalResults.push_back(insertOp.getPosition());
1233 additionalResultTypes.push_back(insertOp.getPosition().getType());
1238 rewriter, warpOp, additionalResults, additionalResultTypes,
1241 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1242 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1243 Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) :
Value();
1246 if (vecType == distrType) {
1248 Value newInsert = rewriter.
create<vector::InsertElementOp>(
1249 loc, newSource, distributedVec, newPos);
1256 int64_t elementsPerLane = distrType.getShape()[0];
1259 Value insertingLane = rewriter.
create<affine::AffineApplyOp>(
1260 loc, sym0.
ceilDiv(elementsPerLane), newPos);
1263 elementsPerLane == 1
1264 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1266 .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1269 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1270 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1274 loc, isInsertingLane,
1277 Value newInsert = builder.create<vector::InsertElementOp>(
1278 loc, newSource, distributedVec, pos);
1279 builder.create<scf::YieldOp>(loc, newInsert);
1283 builder.create<scf::YieldOp>(loc, distributedVec);
1297 warpOp, [](
Operation *op) {
return isa<vector::InsertOp>(op); });
1305 if (insertOp.getNumIndices() == 0)
1309 if (insertOp.getDestVectorType().getRank() == 1) {
1310 if (insertOp.hasDynamicPosition())
1314 assert(insertOp.getNumIndices() == 1 &&
"expected 1 index");
1315 int64_t pos = insertOp.getStaticPosition()[0];
1318 insertOp, insertOp.getSource(), insertOp.getDest(),
1319 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
1323 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1328 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1329 {insertOp.getSourceType(), insertOp.getDestVectorType()},
1332 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1333 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1334 Value newResult = rewriter.
create<vector::InsertOp>(
1335 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1342 auto distrDestType =
1343 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1344 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1345 int64_t distrDestDim = -1;
1346 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1347 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1350 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1354 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1357 VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1359 srcVecType.getShape().end());
1366 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1367 if (distrSrcDim >= 0)
1368 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1375 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1376 {distrSrcType, distrDestType}, newRetIndices);
1378 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1379 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1383 if (distrSrcDim >= 0) {
1385 newResult = rewriter.
create<vector::InsertOp>(
1386 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1389 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1393 Value insertingLane = rewriter.
create<arith::ConstantIndexOp>(
1394 loc, newPos[distrDestDim] / elementsPerLane);
1395 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1396 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1398 newPos[distrDestDim] %= elementsPerLane;
1400 Value newInsert = builder.
create<vector::InsertOp>(
1401 loc, distributedSrc, distributedDest, newPos);
1402 builder.
create<scf::YieldOp>(loc, newInsert);
1405 builder.
create<scf::YieldOp>(loc, distributedDest);
1407 newResult = rewriter
1408 .
create<scf::IfOp>(loc, isInsertingLane,
1410 nonInsertingBuilder)
1455 distributionMapFn(std::move(fn)) {}
1459 auto yield = cast<vector::YieldOp>(
1460 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1462 Operation *lastNode = yield->getPrevNode();
1463 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1469 llvm::SmallSetVector<Value, 32> escapingValues;
1473 forOp.getBodyRegion(), [&](
OpOperand *operand) {
1474 Operation *parent = operand->get().getParentRegion()->getParentOp();
1475 if (warpOp->isAncestor(parent)) {
1476 if (!escapingValues.insert(operand->get()))
1478 Type distType = operand->get().getType();
1479 if (auto vecType = dyn_cast<VectorType>(distType)) {
1480 AffineMap map = distributionMapFn(operand->get());
1481 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1483 inputTypes.push_back(operand->get().getType());
1484 distTypes.push_back(distType);
1490 rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1492 yield = cast<vector::YieldOp>(
1493 newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1498 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
1499 if (yieldOperand.get().
getDefiningOp() != forOp.getOperation())
1501 auto forResult = cast<OpResult>(yieldOperand.get());
1502 newOperands.push_back(
1503 newWarpOp.getResult(yieldOperand.getOperandNumber()));
1504 yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1505 resultIdx.push_back(yieldOperand.getOperandNumber());
1513 auto newForOp = rewriter.
create<scf::ForOp>(
1514 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1515 forOp.getStep(), newOperands);
1519 newForOp.getRegionIterArgs().end());
1521 forOp.getResultTypes().end());
1522 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1524 warpInput.push_back(newWarpOp.getResult(retIdx));
1525 argIndexMapping[escapingValues[i]] = warpInputType.size();
1526 warpInputType.push_back(inputTypes[i]);
1528 auto innerWarp = rewriter.
create<WarpExecuteOnLane0Op>(
1529 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1530 newWarpOp.getWarpSize(), warpInput, warpInputType);
1533 argMapping.push_back(newForOp.getInductionVar());
1534 for (
Value args : innerWarp.getBody()->getArguments()) {
1535 argMapping.push_back(args);
1537 argMapping.resize(forOp.getBody()->getNumArguments());
1539 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
1540 yieldOperands.push_back(operand);
1541 rewriter.
eraseOp(forOp.getBody()->getTerminator());
1542 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1544 rewriter.
create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1546 if (!innerWarp.getResults().empty())
1547 rewriter.
create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1552 newForOp.getResult(res.index()));
1553 newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1557 auto it = argIndexMapping.find(operand.
get());
1558 if (it == argIndexMapping.end())
1560 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1591 DistributedReductionFn distributedReductionFn,
1594 distributedReductionFn(std::move(distributedReductionFn)) {}
1599 warpOp, [](
Operation *op) {
return isa<vector::ReductionOp>(op); });
1605 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1607 if (vectorType.getRank() != 1)
1609 warpOp,
"Only rank 1 reductions can be distributed.");
1611 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1613 warpOp,
"Reduction vector dimension must match was size.");
1614 if (!reductionOp.getType().isIntOrFloat())
1616 warpOp,
"Reduction distribution currently only supports floats and "
1619 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1625 if (reductionOp.getAcc()) {
1626 yieldValues.push_back(reductionOp.getAcc());
1627 retTypes.push_back(reductionOp.getAcc().getType());
1631 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1635 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1638 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1639 reductionOp.getKind(), newWarpOp.getWarpSize());
1640 if (reductionOp.getAcc()) {
1642 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1643 newWarpOp.getResult(newRetIndices[1]));
1650 DistributedReductionFn distributedReductionFn;
1661 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1664 patterns.
add<WarpOpTransferWrite>(patterns.
getContext(), distributionMapFn,
1668 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1670 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit) {
1671 patterns.
add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
1672 WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
1673 WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
1674 WarpOpInsert>(patterns.
getContext(), benefit);
1676 warpShuffleFromIdxFn, benefit);
1677 patterns.
add<WarpOpScfForOp>(patterns.
getContext(), distributionMapFn,
1681 void mlir::vector::populateDistributeReduction(
1683 const DistributedReductionFn &distributedReductionFn,
1685 patterns.
add<WarpOpReduction>(patterns.
getContext(), distributedReductionFn,
1689 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1690 Block *body = warpOp.getBody();
1693 llvm::SmallSetVector<Operation *, 8> opsToMove;
1696 auto isDefinedOutsideOfBody = [&](
Value value) {
1698 return (definingOp && opsToMove.count(definingOp)) ||
1699 warpOp.isDefinedOutsideOfRegion(value);
1705 bool hasVectorResult = llvm::any_of(op.
getResults(), [](
Value result) {
1706 return isa<VectorType>(result.getType());
1708 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
1709 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, const std::function< bool(Operation *)> &fn)
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes)
Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, llvm::SmallVector< size_t > &indices)
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.