18 #include "llvm/ADT/SetVector.h"
36 VectorType distributedType) {
42 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
43 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
47 distributedType.getContext());
48 assert(map.getNumResults() <= 1 &&
49 "only support distribution along one dimension for now.");
60 struct DistributedLoadStoreHelper {
61 DistributedLoadStoreHelper(
Value sequentialVal,
Value distributedVal,
63 : sequentialVal(sequentialVal), distributedVal(distributedVal),
64 laneId(laneId), zero(zero) {
66 distributedVectorType = distributedVal.
getType().
dyn_cast<VectorType>();
67 if (sequentialVectorType && distributedVectorType)
73 int64_t distributedSize = distributedVectorType.getDimSize(index);
75 return b.
createOrFold<AffineApplyOp>(loc, tid * distributedSize,
88 assert((val == distributedVal || val == sequentialVal) &&
89 "Must store either the preregistered distributed or the "
90 "preregistered sequential value.");
93 return b.
create<memref::StoreOp>(loc, val, buffer, zero);
97 int64_t rank = sequentialVectorType.getRank();
99 if (val == distributedVal) {
100 for (
auto dimExpr : distributionMap.getResults()) {
102 indices[index] = buildDistributedOffset(b, loc, index);
106 return b.
create<vector::TransferWriteOp>(
107 loc, val, buffer, indices,
134 if (!type.
isa<VectorType>())
135 return b.
create<memref::LoadOp>(loc, buffer, zero);
140 assert((type == distributedVectorType || type == sequentialVectorType) &&
141 "Must store either the preregistered distributed or the "
142 "preregistered sequential type.");
144 if (type == distributedVectorType) {
145 for (
auto dimExpr : distributionMap.getResults()) {
147 indices[index] = buildDistributedOffset(b, loc, index);
151 return b.
create<vector::TransferReadOp>(
152 loc, type.
cast<VectorType>(), buffer, indices,
156 Value sequentialVal, distributedVal, laneId, zero;
157 VectorType sequentialVectorType, distributedVectorType;
170 auto newWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
171 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
172 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
174 Region &opBody = warpOp.getBodyRegion();
175 Region &newOpBody = newWarpOp.getBodyRegion();
179 assert(newWarpOp.getWarpRegion().hasOneBlock() &&
180 "expected WarpOp with single block");
183 cast<vector::YieldOp>(newOpBody.
getBlocks().begin()->getTerminator());
186 yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
197 warpOp.getResultTypes().end());
198 auto yield = cast<vector::YieldOp>(
199 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
200 llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
201 yield.getOperands().end());
202 for (
auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
203 if (yieldValues.insert(std::get<0>(newRet))) {
204 types.push_back(std::get<1>(newRet));
205 indices.push_back(yieldValues.size() - 1);
208 for (
auto [idx, yieldOperand] :
210 if (yieldOperand == std::get<0>(newRet)) {
211 indices.push_back(idx);
217 yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
219 rewriter, warpOp, yieldValues.getArrayRef(), types);
221 newWarpOp.getResults().take_front(warpOp.getNumResults()));
228 return llvm::all_of(op->
getOperands(), definedOutside) &&
235 const std::function<
bool(
Operation *)> &fn) {
236 auto yield = cast<vector::YieldOp>(
237 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
238 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
239 Value yieldValues = yieldOperand.get();
241 if (definedOp && fn(definedOp)) {
242 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
243 return &yieldOperand;
257 return rewriter.
create(res);
290 struct WarpOpToScfIfPattern :
public OpRewritePattern<WarpExecuteOnLane0Op> {
299 assert(warpOp.getBodyRegion().hasOneBlock() &&
300 "expected WarpOp with single block");
301 Block *warpOpBody = &warpOp.getBodyRegion().
front();
309 Value c0 = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
311 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
312 auto ifOp = rewriter.
create<scf::IfOp>(loc, isLane0,
314 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
321 Value distributedVal = it.value();
322 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
323 warpOp.getLaneid(), c0);
327 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
330 helper.buildStore(rewriter, loc, distributedVal, buffer);
333 bbArgReplacements.push_back(
334 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
338 if (!warpOp.getArgs().empty()) {
340 options.warpSyncronizationFn(loc, rewriter, warpOp);
344 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
351 auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
352 Location yieldLoc = yieldOp.getLoc();
354 Value sequentialVal = it.value();
355 Value distributedVal = warpOp->getResult(it.index());
356 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
357 warpOp.getLaneid(), c0);
361 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
367 helper.buildStore(rewriter, loc, sequentialVal, buffer);
378 replacements.push_back(
379 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
383 if (!yieldOp.getOperands().empty()) {
385 options.warpSyncronizationFn(loc, rewriter, warpOp);
391 rewriter.
create<scf::YieldOp>(yieldLoc);
394 rewriter.
replaceOp(warpOp, replacements);
407 static vector::TransferWriteOp cloneWriteOp(
RewriterBase &rewriter,
408 WarpExecuteOnLane0Op warpOp,
409 vector::TransferWriteOp writeOp,
410 VectorType targetType) {
411 assert(writeOp->getParentOp() == warpOp &&
412 "write must be nested immediately under warp");
416 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
420 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
422 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
434 static VectorType getDistributedType(VectorType originalType,
AffineMap map,
439 originalType.getShape().end());
442 if (targetShape[position] % warpSize != 0)
444 targetShape[position] = targetShape[position] / warpSize;
446 VectorType targetType =
447 VectorType::get(targetShape, originalType.getElementType());
468 struct WarpOpTransferWrite :
public OpRewritePattern<vector::TransferWriteOp> {
472 distributionMapFn(std::move(fn)) {}
477 vector::TransferWriteOp writeOp,
478 WarpExecuteOnLane0Op warpOp)
const {
479 VectorType writtenVectorType = writeOp.getVectorType();
483 if (writtenVectorType.getRank() == 0)
487 AffineMap map = distributionMapFn(writeOp.getVector());
488 VectorType targetType =
489 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
495 vector::TransferWriteOp newWriteOp =
496 cloneWriteOp(rewriter, warpOp, writeOp, targetType);
500 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
505 newWriteOp.getIndices().end());
508 bindDims(newWarpOp.getContext(), d0, d1);
513 unsigned vectorPos = std::get<1>(it).cast<
AffineDimExpr>().getPosition();
518 {indices[indexPos], newWarpOp.getLaneid()});
520 newWriteOp.getIndicesMutable().assign(indices);
527 vector::TransferWriteOp writeOp,
528 WarpExecuteOnLane0Op warpOp)
const {
530 VectorType vecType = writeOp.getVectorType();
534 if (vecType.getNumElements() != 1)
538 if (llvm::all_of(warpOp.getOps(), [](
Operation &op) {
539 return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
547 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
551 auto secondWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
552 loc,
TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
553 Block &body = secondWarpOp.getBodyRegion().
front();
556 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
557 newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
559 rewriter.
create<vector::YieldOp>(newWarpOp.getLoc());
563 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
566 if (writeOp.getMask())
569 auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
574 Operation *nextOp = writeOp.getOperation();
575 while ((nextOp = nextOp->getNextNode()))
579 if (!llvm::all_of(writeOp->getOperands(), [&](
Value value) {
580 return writeOp.getVector() == value ||
581 warpOp.isDefinedOutsideOfRegion(value);
585 if (
succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
588 if (
succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
627 Value distributedVal = warpOp.getResult(operandIndex);
635 auto operandType = operand.get().getType().
cast<VectorType>();
637 VectorType::get(vecType.getShape(), operandType.getElementType());
639 auto operandType = operand.get().getType();
640 assert(!operandType.isa<VectorType>() &&
641 "unexpected yield of vector from op with scalar result type");
642 targetType = operandType;
644 retTypes.push_back(targetType);
645 yieldValues.push_back(operand.get());
649 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
653 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
654 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
659 rewriter, loc, elementWise, newOperands,
660 {newWarpOp.getResult(operandIndex).getType()});
686 warpOp, [](
Operation *op) {
return isa<arith::ConstantOp>(op); });
696 warpOp.getResult(operandIndex).getType(), scalarAttr);
699 Value distConstant = rewriter.
create<arith::ConstantOp>(loc, newAttr);
728 warpOp, [](
Operation *op) {
return isa<vector::TransferReadOp>(op); });
733 if (!read.getResult().hasOneUse())
736 Value distributedVal = warpOp.getResult(operandIndex);
739 read.getIndices().end());
740 auto sequentialType = read.getResult().getType().cast<VectorType>();
741 auto distributedType = distributedVal.
getType().
cast<VectorType>();
746 for (
auto it : llvm::zip(indexMap.
getResults(), map.getResults())) {
748 bindDims(read.getContext(), d0, d1);
753 unsigned vectorPos = std::get<1>(it).cast<
AffineDimExpr>().getPosition();
755 distributedVal.
getType().
cast<VectorType>().getDimSize(vectorPos);
758 {indices[indexPos], warpOp.getLaneid()});
760 Value newRead = rewriter.
create<vector::TransferReadOp>(
761 read.getLoc(), distributedVal.
getType(), read.getSource(), indices,
762 read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
763 read.getInBoundsAttr());
776 newResultTypes.reserve(warpOp->getNumResults());
778 newYieldValues.reserve(warpOp->getNumResults());
781 auto yield = cast<vector::YieldOp>(
782 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
792 for (
OpResult result : warpOp.getResults()) {
793 Value yieldOperand = yield.getOperand(result.getResultNumber());
794 auto it = dedupYieldOperandPositionMap.insert(
795 std::make_pair(yieldOperand, newResultTypes.size()));
796 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
797 if (result.use_empty() || !it.second)
799 newResultTypes.push_back(result.getType());
800 newYieldValues.push_back(yieldOperand);
803 if (yield.getNumOperands() == newYieldValues.size())
807 rewriter, warpOp, newYieldValues, newResultTypes);
810 newValues.reserve(warpOp->getNumResults());
811 for (
OpResult result : warpOp.getResults()) {
812 if (result.use_empty())
813 newValues.push_back(
Value());
816 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
825 struct WarpOpForwardOperand :
public OpRewritePattern<WarpExecuteOnLane0Op> {
831 auto yield = cast<vector::YieldOp>(
832 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
834 unsigned resultIndex;
835 for (
OpOperand &operand : yield->getOpOperands()) {
844 valForwarded = operand.
get();
849 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
851 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
854 valForwarded = warpOperand;
870 warpOp, [](
Operation *op) {
return isa<vector::BroadcastOp>(op); });
875 Location loc = broadcastOp.getLoc();
877 warpOp->getResultTypes()[operandNumber].
cast<VectorType>();
880 rewriter, warpOp, {broadcastOp.getSource()},
881 {broadcastOp.getSource().getType()}, newRetIndices);
883 Value broadcasted = rewriter.
create<vector::BroadcastOp>(
884 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
898 warpOp, [](
Operation *op) {
return isa<vector::ExtractOp>(op); });
903 VectorType extractSrcType = extractOp.getSourceVectorType();
907 assert(extractSrcType.getRank() > 0 &&
908 "vector.extract does not support rank 0 sources");
911 if (extractOp.getPosition().empty())
915 if (extractSrcType.getRank() == 1) {
916 assert(extractOp.getPosition().size() == 1 &&
"expected 1 index");
917 int64_t pos = extractOp.getPosition()[0].cast<IntegerAttr>().getInt();
920 extractOp, extractOp.getVector(),
921 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
927 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
935 rewriter, warpOp, {extractOp.getVector()},
936 {extractOp.getSourceVectorType()}, newRetIndices);
938 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
940 Value newExtract = rewriter.
create<vector::ExtractOp>(
941 loc, distributedVec, extractOp.getPosition());
948 auto distributedType =
949 warpOp.getResult(operandNumber).getType().cast<VectorType>();
951 int64_t distributedDim = -1;
952 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
953 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
956 assert(distributedDim == -1 &&
"found multiple distributed dims");
960 assert(distributedDim != -1 &&
"could not find distributed dimension");
961 (void)distributedDim;
965 extractSrcType.getShape().end());
966 for (
int i = 0; i < distributedType.getRank(); ++i)
967 newDistributedShape[i + extractOp.getPosition().size()] =
968 distributedType.getDimSize(i);
969 auto newDistributedType =
970 VectorType::get(newDistributedShape, distributedType.getElementType());
973 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
976 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
978 Value newExtract = rewriter.
create<vector::ExtractOp>(
979 loc, distributedVec, extractOp.getPosition());
988 struct WarpOpExtractElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
989 WarpOpExtractElement(
MLIRContext *ctx, WarpShuffleFromIdxFn fn,
992 warpShuffleFromIdxFn(std::move(fn)) {}
996 return isa<vector::ExtractElementOp>(op);
1002 VectorType extractSrcType = extractOp.getSourceVectorType();
1003 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1004 Type elType = extractSrcType.getElementType();
1005 VectorType distributedVecType;
1006 if (!is0dOrVec1Extract) {
1007 assert(extractSrcType.getRank() == 1 &&
1008 "expected that extractelement src rank is 0 or 1");
1009 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1011 int64_t elementsPerLane =
1012 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1013 distributedVecType = VectorType::get({elementsPerLane}, elType);
1015 distributedVecType = extractSrcType;
1021 rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
1024 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1028 if (is0dOrVec1Extract) {
1030 if (extractSrcType.getRank() == 1) {
1031 newExtract = rewriter.
create<vector::ExtractElementOp>(
1032 loc, distributedVec,
1033 rewriter.
create<arith::ConstantIndexOp>(loc, 0));
1037 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec);
1046 int64_t elementsPerLane = distributedVecType.getShape()[0];
1049 Value broadcastFromTid = rewriter.
create<AffineApplyOp>(
1050 loc, sym0.
ceilDiv(elementsPerLane), extractOp.getPosition());
1053 elementsPerLane == 1
1054 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1056 .create<AffineApplyOp>(loc, sym0 % elementsPerLane,
1057 extractOp.getPosition())
1060 rewriter.
create<vector::ExtractElementOp>(loc, distributedVec, pos);
1063 Value shuffled = warpShuffleFromIdxFn(
1064 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1070 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1073 struct WarpOpInsertElement :
public OpRewritePattern<WarpExecuteOnLane0Op> {
1079 warpOp, [](
Operation *op) {
return isa<vector::InsertElementOp>(op); });
1084 VectorType vecType = insertOp.getDestVectorType();
1085 VectorType distrType =
1087 bool hasPos =
static_cast<bool>(insertOp.getPosition());
1091 insertOp.getSource()};
1093 insertOp.getSource().getType()};
1095 additionalResults.push_back(insertOp.getPosition());
1096 additionalResultTypes.push_back(insertOp.getPosition().getType());
1101 rewriter, warpOp, additionalResults, additionalResultTypes,
1104 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1105 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1106 Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) :
Value();
1109 if (vecType == distrType) {
1111 Value newInsert = rewriter.
create<vector::InsertElementOp>(
1112 loc, newSource, distributedVec, newPos);
1119 int64_t elementsPerLane = distrType.getShape()[0];
1122 Value insertingLane = rewriter.
create<AffineApplyOp>(
1123 loc, sym0.
ceilDiv(elementsPerLane), newPos);
1126 elementsPerLane == 1
1127 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1129 .create<AffineApplyOp>(loc, sym0 % elementsPerLane, newPos)
1131 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1132 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1136 loc, isInsertingLane,
1139 Value newInsert = builder.create<vector::InsertElementOp>(
1140 loc, newSource, distributedVec, pos);
1141 builder.create<scf::YieldOp>(loc, newInsert);
1145 builder.create<scf::YieldOp>(loc, distributedVec);
1159 warpOp, [](
Operation *op) {
return isa<vector::InsertOp>(op); });
1167 if (insertOp.getPosition().empty())
1171 if (insertOp.getDestVectorType().getRank() == 1) {
1172 assert(insertOp.getPosition().size() == 1 &&
"expected 1 index");
1173 int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt();
1176 insertOp, insertOp.getSource(), insertOp.getDest(),
1177 rewriter.
create<arith::ConstantIndexOp>(loc, pos));
1181 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1186 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1187 {insertOp.getSourceType(), insertOp.getDestVectorType()},
1190 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1191 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1192 Value newResult = rewriter.
create<vector::InsertOp>(
1193 loc, distributedSrc, distributedDest, insertOp.getPosition());
1200 auto distrDestType =
1201 warpOp.getResult(operandNumber).getType().cast<VectorType>();
1203 int64_t distrDestDim = -1;
1204 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1205 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1208 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1212 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1215 VectorType srcVecType = insertOp.getSourceType().cast<VectorType>();
1217 srcVecType.getShape().end());
1224 int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size();
1225 if (distrSrcDim >= 0)
1226 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1228 VectorType::get(distrSrcShape, distrDestType.getElementType());
1233 rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1234 {distrSrcType, distrDestType}, newRetIndices);
1236 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1237 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1241 if (distrSrcDim >= 0) {
1243 newResult = rewriter.
create<vector::InsertOp>(
1244 loc, distributedSrc, distributedDest, insertOp.getPosition());
1247 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1249 llvm::map_range(insertOp.getPosition(), [](
Attribute attr) {
1250 return attr.cast<IntegerAttr>().getInt();
1253 Value insertingLane = rewriter.
create<arith::ConstantIndexOp>(
1254 loc, newPos[distrDestDim] / elementsPerLane);
1255 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1256 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1258 newPos[distrDestDim] %= elementsPerLane;
1260 Value newInsert = builder.
create<vector::InsertOp>(
1261 loc, distributedSrc, distributedDest, newPos);
1262 builder.
create<scf::YieldOp>(loc, newInsert);
1265 builder.
create<scf::YieldOp>(loc, distributedDest);
1267 newResult = rewriter
1268 .
create<scf::IfOp>(loc, isInsertingLane,
1270 nonInsertingBuilder)
1315 distributionMapFn(std::move(fn)) {}
1319 auto yield = cast<vector::YieldOp>(
1320 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1322 Operation *lastNode = yield->getPrevNode();
1323 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1329 llvm::SmallSetVector<Value, 32> escapingValues;
1333 forOp.getBodyRegion(), [&](
OpOperand *operand) {
1334 Operation *parent = operand->get().getParentRegion()->getParentOp();
1335 if (warpOp->isAncestor(parent)) {
1336 if (!escapingValues.insert(operand->get()))
1338 Type distType = operand->get().getType();
1339 if (auto vecType = distType.cast<VectorType>()) {
1340 AffineMap map = distributionMapFn(operand->get());
1341 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1343 inputTypes.push_back(operand->get().getType());
1344 distTypes.push_back(distType);
1350 rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1352 yield = cast<vector::YieldOp>(
1353 newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1358 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
1359 if (yieldOperand.get().
getDefiningOp() != forOp.getOperation())
1362 newOperands.push_back(
1363 newWarpOp.getResult(yieldOperand.getOperandNumber()));
1364 yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
1365 resultIdx.push_back(yieldOperand.getOperandNumber());
1373 auto newForOp = rewriter.
create<scf::ForOp>(
1374 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1375 forOp.getStep(), newOperands);
1379 newForOp.getRegionIterArgs().end());
1381 forOp.getResultTypes().end());
1382 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1384 warpInput.push_back(newWarpOp.getResult(retIdx));
1385 argIndexMapping[escapingValues[i]] = warpInputType.size();
1386 warpInputType.push_back(inputTypes[i]);
1388 auto innerWarp = rewriter.
create<WarpExecuteOnLane0Op>(
1389 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1390 newWarpOp.getWarpSize(), warpInput, warpInputType);
1393 argMapping.push_back(newForOp.getInductionVar());
1394 for (
Value args : innerWarp.getBody()->getArguments()) {
1395 argMapping.push_back(args);
1397 argMapping.resize(forOp.getBody()->getNumArguments());
1399 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
1400 yieldOperands.push_back(operand);
1401 rewriter.
eraseOp(forOp.getBody()->getTerminator());
1402 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1404 rewriter.
create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1406 if (!innerWarp.getResults().empty())
1407 rewriter.
create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1412 newForOp.getResult(res.index()));
1413 newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1417 auto it = argIndexMapping.find(operand.
get());
1418 if (it == argIndexMapping.end())
1420 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1451 DistributedReductionFn distributedReductionFn,
1454 distributedReductionFn(std::move(distributedReductionFn)) {}
1459 warpOp, [](
Operation *op) {
return isa<vector::ReductionOp>(op); });
1465 auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
1467 if (vectorType.getRank() != 1)
1469 warpOp,
"Only rank 1 reductions can be distributed.");
1471 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1473 warpOp,
"Reduction vector dimension must match was size.");
1474 if (!reductionOp.getType().isIntOrFloat())
1476 warpOp,
"Reduction distribution currently only supports floats and "
1479 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1484 VectorType::get({numElements}, reductionOp.getType())};
1485 if (reductionOp.getAcc()) {
1486 yieldValues.push_back(reductionOp.getAcc());
1487 retTypes.push_back(reductionOp.getAcc().getType());
1491 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1495 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1498 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1499 reductionOp.getKind(), newWarpOp.getWarpSize());
1500 if (reductionOp.getAcc()) {
1502 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1503 newWarpOp.getResult(newRetIndices[1]));
1510 DistributedReductionFn distributedReductionFn;
1521 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1524 patterns.
add<WarpOpTransferWrite>(patterns.
getContext(), distributionMapFn,
1528 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1530 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit) {
1531 patterns.
add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
1532 WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
1533 WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
1536 warpShuffleFromIdxFn, benefit);
1537 patterns.
add<WarpOpScfForOp>(patterns.
getContext(), distributionMapFn,
1541 void mlir::vector::populateDistributeReduction(
1543 const DistributedReductionFn &distributedReductionFn,
1545 patterns.
add<WarpOpReduction>(patterns.
getContext(), distributedReductionFn,
1549 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1550 Block *body = warpOp.getBody();
1553 llvm::SmallSetVector<Operation *, 8> opsToMove;
1556 auto isDefinedOutsideOfBody = [&](
Value value) {
1558 return (definingOp && opsToMove.count(definingOp)) ||
1559 warpOp.isDefinedOutsideOfRegion(value);
1565 bool hasVectorResult = llvm::any_of(op.
getResults(), [](
Value result) {
1566 return result.getType().isa<VectorType>();
1568 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
1569 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, const std::function< bool(Operation *)> &fn)
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes)
Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, llvm::SmallVector< size_t > &indices)
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.