20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/FormatVariadic.h"
40 VectorType distributedType) {
46 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
47 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
51 distributedType.getContext());
62 struct DistributedLoadStoreHelper {
63 DistributedLoadStoreHelper(
Value sequentialVal,
Value distributedVal,
65 : sequentialVal(sequentialVal), distributedVal(distributedVal),
66 laneId(laneId), zero(zero) {
67 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.
getType());
68 distributedVectorType = dyn_cast<VectorType>(distributedVal.
getType());
69 if (sequentialVectorType && distributedVectorType)
75 int64_t distributedSize = distributedVectorType.getDimSize(index);
77 return b.
createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
90 assert((val == distributedVal || val == sequentialVal) &&
91 "Must store either the preregistered distributed or the "
92 "preregistered sequential value.");
94 if (!isa<VectorType>(val.
getType()))
95 return b.
create<memref::StoreOp>(loc, val, buffer, zero);
99 int64_t rank = sequentialVectorType.getRank();
101 if (val == distributedVal) {
102 for (
auto dimExpr : distributionMap.getResults()) {
103 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
104 indices[index] = buildDistributedOffset(b, loc, index);
108 return b.
create<vector::TransferWriteOp>(
109 loc, val, buffer, indices,
136 if (!isa<VectorType>(type))
137 return b.
create<memref::LoadOp>(loc, buffer, zero);
142 assert((type == distributedVectorType || type == sequentialVectorType) &&
143 "Must store either the preregistered distributed or the "
144 "preregistered sequential type.");
146 if (type == distributedVectorType) {
147 for (
auto dimExpr : distributionMap.getResults()) {
148 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
149 indices[index] = buildDistributedOffset(b, loc, index);
153 return b.
create<vector::TransferReadOp>(
154 loc, cast<VectorType>(type), buffer, indices,
158 Value sequentialVal, distributedVal, laneId, zero;
159 VectorType sequentialVectorType, distributedVectorType;
173 return rewriter.
create(res);
212 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
214 assert(warpOp.getBodyRegion().hasOneBlock() &&
215 "expected WarpOp with single block");
216 Block *warpOpBody = &warpOp.getBodyRegion().
front();
224 Value c0 = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
226 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
227 auto ifOp = rewriter.
create<scf::IfOp>(loc, isLane0,
229 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
236 Value distributedVal = it.value();
237 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
238 warpOp.getLaneid(), c0);
242 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
245 helper.buildStore(rewriter, loc, distributedVal, buffer);
248 bbArgReplacements.push_back(
249 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
253 if (!warpOp.getArgs().empty()) {
255 options.warpSyncronizationFn(loc, rewriter, warpOp);
259 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
266 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
267 Location yieldLoc = yieldOp.getLoc();
269 Value sequentialVal = it.value();
270 Value distributedVal = warpOp->getResult(it.index());
271 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
272 warpOp.getLaneid(), c0);
276 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
282 helper.buildStore(rewriter, loc, sequentialVal, buffer);
293 replacements.push_back(
294 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
298 if (!yieldOp.getOperands().empty()) {
300 options.warpSyncronizationFn(loc, rewriter, warpOp);
306 rewriter.
create<scf::YieldOp>(yieldLoc);
309 rewriter.
replaceOp(warpOp, replacements);
326 static VectorType getDistributedType(VectorType originalType,
AffineMap map,
331 if (targetShape[position] % warpSize != 0) {
332 if (warpSize % targetShape[position] != 0) {
335 warpSize /= targetShape[position];
336 targetShape[position] = 1;
339 targetShape[position] = targetShape[position] / warpSize;
346 VectorType targetType =
374 maxNumElementsToExtract(maxNumElementsToExtract) {}
379 vector::TransferWriteOp writeOp,
380 WarpExecuteOnLane0Op warpOp)
const {
381 VectorType writtenVectorType = writeOp.getVectorType();
385 if (writtenVectorType.getRank() == 0)
389 AffineMap map = distributionMapFn(writeOp.getVector());
390 VectorType targetType =
391 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
397 if (writeOp.getMask()) {
404 if (!writeOp.getPermutationMap().isMinorIdentity())
407 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
412 vector::TransferWriteOp newWriteOp =
413 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
417 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
424 for (
auto [seqSize, distSize] :
425 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
426 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
427 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
431 delinearized = rewriter
432 .
create<mlir::affine::AffineDelinearizeIndexOp>(
433 newWarpOp.getLoc(), newWarpOp.getLaneid(),
439 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
445 newWriteOp.getIndices().end());
448 bindDims(newWarpOp.getContext(), d0, d1);
449 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
452 unsigned indexPos = indexExpr.getPosition();
453 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
454 Value laneId = delinearized[vectorPos];
458 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
460 newWriteOp.getIndicesMutable().assign(indices);
467 vector::TransferWriteOp writeOp,
468 WarpExecuteOnLane0Op warpOp)
const {
470 VectorType vecType = writeOp.getVectorType();
472 if (vecType.getNumElements() > maxNumElementsToExtract) {
476 "writes more elements ({0}) than allowed to extract ({1})",
477 vecType.getNumElements(), maxNumElementsToExtract));
481 if (llvm::all_of(warpOp.getOps(),
482 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
488 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
489 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
493 auto secondWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
494 loc,
TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
495 Block &body = secondWarpOp.getBodyRegion().
front();
498 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
499 newWriteOp.getValueToStoreMutable().assign(
500 newWarpOp.getResult(newRetIndices[0]));
502 rewriter.
create<gpu::YieldOp>(newWarpOp.getLoc());
506 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
508 auto yield = cast<gpu::YieldOp>(
509 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
510 Operation *lastNode = yield->getPrevNode();
511 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
515 Value maybeMask = writeOp.getMask();
516 if (!llvm::all_of(writeOp->getOperands(), [&](
Value value) {
517 return writeOp.getVector() == value ||
518 (maybeMask && maybeMask == value) ||
519 warpOp.isDefinedOutsideOfRegion(value);
523 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
527 if (writeOp.getMask())
530 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
540 vector::TransferWriteOp cloneWriteOp(
RewriterBase &rewriter,
541 WarpExecuteOnLane0Op warpOp,
542 vector::TransferWriteOp writeOp,
543 VectorType targetType,
544 VectorType maybeMaskType)
const {
545 assert(writeOp->getParentOp() == warpOp &&
546 "write must be nested immediately under warp");
549 WarpExecuteOnLane0Op newWarpOp;
551 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
552 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
553 TypeRange{targetType, maybeMaskType}, newRetIndices);
555 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
556 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
561 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
563 newWriteOp.getValueToStoreMutable().assign(
564 newWarpOp.getResult(newRetIndices[0]));
566 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
571 unsigned maxNumElementsToExtract = 1;
594 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
604 Value distributedVal = warpOp.getResult(operandIndex);
610 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
612 auto operandType = cast<VectorType>(operand.get().getType());
616 auto operandType = operand.get().getType();
617 assert(!isa<VectorType>(operandType) &&
618 "unexpected yield of vector from op with scalar result type");
619 targetType = operandType;
621 retTypes.push_back(targetType);
622 yieldValues.push_back(operand.get());
625 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
626 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
630 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
631 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
636 rewriter, loc, elementWise, newOperands,
637 {newWarpOp.getResult(operandIndex).
getType()});
660 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
663 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
667 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
676 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
679 Value distConstant = rewriter.
create<arith::ConstantOp>(loc, newAttr);
706 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
713 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
717 warpOp,
"warp result is not a vector.transfer_read op");
721 if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
723 read,
"source must be defined outside of the region");
726 Value distributedVal = warpOp.getResult(operandIndex);
729 read.getIndices().end());
730 auto sequentialType = cast<VectorType>(read.getResult().getType());
731 auto distributedType = cast<VectorType>(distributedVal.
getType());
738 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
739 distributedType.getShape(), warpOp.getWarpSize(),
740 warpOp.getLaneid(), delinearizedIds)) {
742 read,
"cannot delinearize lane ID for distribution");
744 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
751 additionalResults.push_back(read.getPadding());
752 additionalResultTypes.push_back(read.getPadding().getType());
754 bool hasMask =
false;
755 if (read.getMask()) {
765 read,
"non-trivial permutation maps not supported");
766 VectorType maskType =
767 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
768 additionalResults.push_back(read.getMask());
769 additionalResultTypes.push_back(maskType);
773 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
774 rewriter, warpOp, additionalResults, additionalResultTypes,
776 distributedVal = newWarpOp.getResult(operandIndex);
780 for (int64_t i = 0, e = indices.size(); i < e; ++i)
781 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
786 bindDims(read.getContext(), d0, d1);
787 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
790 unsigned indexPos = indexExpr.getPosition();
791 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
792 int64_t scale = distributedType.getDimSize(vectorPos);
794 rewriter, read.getLoc(), d0 + scale * d1,
795 {newIndices[indexPos], delinearizedIds[vectorPos]});
799 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
802 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
804 auto newRead = rewriter.
create<vector::TransferReadOp>(
805 read.getLoc(), distributedVal.
getType(), read.getSource(), newIndices,
806 read.getPermutationMapAttr(), newPadding, newMask,
807 read.getInBoundsAttr());
818 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
821 newResultTypes.reserve(warpOp->getNumResults());
823 newYieldValues.reserve(warpOp->getNumResults());
826 auto yield = cast<gpu::YieldOp>(
827 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
837 for (
OpResult result : warpOp.getResults()) {
838 Value yieldOperand = yield.getOperand(result.getResultNumber());
839 auto it = dedupYieldOperandPositionMap.insert(
840 std::make_pair(yieldOperand, newResultTypes.size()));
841 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
842 if (result.use_empty() || !it.second)
844 newResultTypes.push_back(result.getType());
845 newYieldValues.push_back(yieldOperand);
848 if (yield.getNumOperands() == newYieldValues.size())
851 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
852 rewriter, warpOp, newYieldValues, newResultTypes);
855 newWarpOp.getBody()->walk([&](
Operation *op) {
862 newValues.reserve(warpOp->getNumResults());
863 for (
OpResult result : warpOp.getResults()) {
864 if (result.use_empty())
865 newValues.push_back(
Value());
868 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
879 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
883 auto yield = cast<gpu::YieldOp>(
884 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
886 unsigned resultIndex;
887 for (
OpOperand &operand : yield->getOpOperands()) {
896 valForwarded = operand.
get();
900 auto arg = dyn_cast<BlockArgument>(operand.
get());
901 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
903 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
906 valForwarded = warpOperand;
923 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
926 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
931 Location loc = broadcastOp.getLoc();
933 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
934 Value broadcastSrc = broadcastOp.getSource();
945 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
946 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
948 Value broadcasted = rewriter.
create<vector::BroadcastOp>(
949 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
960 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
963 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
970 auto castDistributedType =
971 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
972 VectorType castOriginalType = oldCastOp.getSourceVectorType();
973 VectorType castResultType = castDistributedType;
977 unsigned castDistributedRank = castDistributedType.getRank();
978 unsigned castOriginalRank = castOriginalType.getRank();
979 if (castDistributedRank < castOriginalRank) {
981 llvm::append_range(shape, castDistributedType.getShape());
982 castDistributedType =
987 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
988 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
991 Value newCast = rewriter.
create<vector::ShapeCastOp>(
992 oldCastOp.getLoc(), castResultType,
993 newWarpOp->getResult(newRetIndices[0]));
1019 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1022 getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1030 if (!llvm::all_of(mask->getOperands(), [&](
Value value) {
1031 return warpOp.isDefinedOutsideOfRegion(value);
1038 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1039 VectorType seqType = mask.getVectorType();
1047 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1048 warpOp.getWarpSize(), warpOp.getLaneid(),
1051 mask,
"cannot delinearize lane ID for distribution");
1052 assert(!delinearizedIds.empty());
1061 for (
int i = 0, e = distShape.size(); i < e; ++i) {
1068 rewriter, loc, s1 - s0 * distShape[i],
1069 {delinearizedIds[i], mask.getOperand(i)});
1070 newOperands.push_back(maskDimIdx);
1074 rewriter.
create<vector::CreateMaskOp>(loc, distType, newOperands);
1085 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1088 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1093 VectorType extractSrcType = extractOp.getSourceVectorType();
1097 if (extractSrcType.getRank() <= 1) {
1103 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1110 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1111 rewriter, warpOp, {extractOp.getVector()},
1112 {extractOp.getSourceVectorType()}, newRetIndices);
1114 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1116 Value newExtract = rewriter.
create<vector::ExtractOp>(
1117 loc, distributedVec, extractOp.getMixedPosition());
1124 auto distributedType =
1125 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1126 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1127 int64_t distributedDim = -1;
1128 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1129 if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1132 assert(distributedDim == -1 &&
"found multiple distributed dims");
1136 assert(distributedDim != -1 &&
"could not find distributed dimension");
1137 (void)distributedDim;
1141 for (
int i = 0; i < distributedType.getRank(); ++i)
1142 newDistributedShape[i + extractOp.getNumIndices()] =
1143 distributedType.getDimSize(i);
1144 auto newDistributedType =
1145 VectorType::get(newDistributedShape, distributedType.getElementType());
1147 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1148 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1151 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1153 Value newExtract = rewriter.
create<vector::ExtractOp>(
1154 loc, distributedVec, extractOp.getMixedPosition());
1164 WarpOpExtractScalar(
MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1167 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1170 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1175 VectorType extractSrcType = extractOp.getSourceVectorType();
1177 if (extractSrcType.getRank() > 1) {
1179 extractOp,
"only 0-D or 1-D source supported for now");
1183 if (!extractSrcType.getElementType().isF32() &&
1184 !extractSrcType.getElementType().isInteger(32))
1186 extractOp,
"only f32/i32 element types are supported");
1187 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1188 Type elType = extractSrcType.getElementType();
1189 VectorType distributedVecType;
1190 if (!is0dOrVec1Extract) {
1191 assert(extractSrcType.getRank() == 1 &&
1192 "expected that extract src rank is 0 or 1");
1193 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1195 int64_t elementsPerLane =
1196 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1199 distributedVecType = extractSrcType;
1204 additionalResults.append(
1206 additionalResultTypes.append(
1211 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1212 rewriter, warpOp, additionalResults, additionalResultTypes,
1215 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1219 if (is0dOrVec1Extract) {
1223 rewriter.
create<vector::ExtractOp>(loc, distributedVec, indices);
1229 int64_t staticPos = extractOp.getStaticPosition()[0];
1231 ? (newWarpOp->getResult(newRetIndices[1]))
1235 int64_t elementsPerLane = distributedVecType.getShape()[0];
1239 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1242 elementsPerLane == 1
1243 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1245 sym0 % elementsPerLane, pos);
1247 rewriter.
create<vector::ExtractOp>(loc, distributedVec, newPos);
1250 Value shuffled = warpShuffleFromIdxFn(
1251 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1257 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1263 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1266 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1271 if (
auto pos = extractOp.getPosition()) {
1272 indices.push_back(pos);
1276 extractOp, extractOp.getVector(), indices);
1285 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1287 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1292 VectorType vecType = insertOp.getDestVectorType();
1293 VectorType distrType =
1294 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1297 if (vecType.getRank() > 1) {
1299 insertOp,
"only 0-D or 1-D source supported for now");
1304 insertOp.getValueToStore()};
1306 distrType, insertOp.getValueToStore().getType()};
1308 additionalResultTypes.append(
1313 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1314 rewriter, warpOp, additionalResults, additionalResultTypes,
1317 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1318 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1322 if (vecType.getRank() != 0) {
1323 int64_t staticPos = insertOp.getStaticPosition()[0];
1324 pos = ShapedType::isDynamic(staticPos)
1325 ? (newWarpOp->getResult(newRetIndices[2]))
1330 if (vecType == distrType) {
1334 indices.push_back(pos);
1336 newInsert = rewriter.
create<vector::InsertOp>(loc, newSource,
1337 distributedVec, indices);
1345 int64_t elementsPerLane = distrType.getShape()[0];
1349 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1352 rewriter, loc, sym0 % elementsPerLane, pos);
1353 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1354 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1358 loc, isInsertingLane,
1361 Value newInsert = builder.create<vector::InsertOp>(
1362 loc, newSource, distributedVec, newPos);
1363 builder.create<scf::YieldOp>(loc, newInsert);
1367 builder.create<scf::YieldOp>(loc, distributedVec);
1377 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1379 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1387 if (insertOp.getDestVectorType().getRank() <= 1) {
1393 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1397 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1398 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1399 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1402 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1403 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1404 Value newResult = rewriter.
create<vector::InsertOp>(
1405 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1412 auto distrDestType =
1413 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1414 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1415 int64_t distrDestDim = -1;
1416 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1417 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1420 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1424 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1427 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1435 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1436 if (distrSrcDim >= 0)
1437 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1443 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1444 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1445 {distrSrcType, distrDestType}, newRetIndices);
1447 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1448 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1452 if (distrSrcDim >= 0) {
1454 newResult = rewriter.
create<vector::InsertOp>(
1455 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1458 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1462 Value insertingLane = rewriter.
create<arith::ConstantIndexOp>(
1463 loc, newPos[distrDestDim] / elementsPerLane);
1464 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1465 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1467 newPos[distrDestDim] %= elementsPerLane;
1469 Value newInsert = builder.
create<vector::InsertOp>(
1470 loc, distributedSrc, distributedDest, newPos);
1471 builder.
create<scf::YieldOp>(loc, newInsert);
1474 builder.
create<scf::YieldOp>(loc, distributedDest);
1476 newResult = rewriter
1477 .
create<scf::IfOp>(loc, isInsertingLane,
1479 nonInsertingBuilder)
1490 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1493 getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1498 if (
auto pos = insertOp.getPosition()) {
1499 indices.push_back(pos);
1503 insertOp, insertOp.getSource(), insertOp.getDest(), indices);
1544 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1546 auto yield = cast<gpu::YieldOp>(
1547 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1549 Operation *lastNode = yield->getPrevNode();
1550 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1556 llvm::SmallSetVector<Value, 32> escapingValues;
1560 forOp.getBodyRegion(), [&](
OpOperand *operand) {
1561 Operation *parent = operand->get().getParentRegion()->getParentOp();
1562 if (warpOp->isAncestor(parent)) {
1563 if (!escapingValues.insert(operand->get()))
1565 Type distType = operand->get().getType();
1566 if (auto vecType = dyn_cast<VectorType>(distType)) {
1567 AffineMap map = distributionMapFn(operand->get());
1568 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1570 inputTypes.push_back(operand->get().getType());
1571 distTypes.push_back(distType);
1575 if (llvm::is_contained(distTypes,
Type{}))
1579 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1580 rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1582 yield = cast<gpu::YieldOp>(
1583 newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1588 for (
OpOperand &yieldOperand : yield->getOpOperands()) {
1591 auto forResult = cast<OpResult>(yieldOperand.
get());
1592 newOperands.push_back(
1594 yieldOperand.
set(forOp.getInitArgs()[forResult.getResultNumber()]);
1603 auto newForOp = rewriter.
create<scf::ForOp>(
1604 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1605 forOp.getStep(), newOperands);
1609 newForOp.getRegionIterArgs().end());
1611 forOp.getResultTypes().end());
1612 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1614 warpInput.push_back(newWarpOp.getResult(retIdx));
1615 argIndexMapping[escapingValues[i]] = warpInputType.size();
1616 warpInputType.push_back(inputTypes[i]);
1618 auto innerWarp = rewriter.
create<WarpExecuteOnLane0Op>(
1619 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1620 newWarpOp.getWarpSize(), warpInput, warpInputType);
1623 argMapping.push_back(newForOp.getInductionVar());
1624 for (
Value args : innerWarp.getBody()->getArguments()) {
1625 argMapping.push_back(args);
1627 argMapping.resize(forOp.getBody()->getNumArguments());
1629 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
1630 yieldOperands.push_back(operand);
1631 rewriter.
eraseOp(forOp.getBody()->getTerminator());
1632 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1634 rewriter.
create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
1636 if (!innerWarp.getResults().empty())
1637 rewriter.
create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1642 newForOp.getResult(res.index()));
1643 newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1647 auto it = argIndexMapping.find(operand.
get());
1648 if (it == argIndexMapping.end())
1650 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1655 mlir::vector::moveScalarUniformCode(innerWarp);
1684 DistributedReductionFn distributedReductionFn,
1687 distributedReductionFn(std::move(distributedReductionFn)) {}
1689 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1692 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1698 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1700 if (vectorType.getRank() != 1)
1702 warpOp,
"Only rank 1 reductions can be distributed.");
1704 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1706 warpOp,
"Reduction vector dimension must match was size.");
1707 if (!reductionOp.getType().isIntOrFloat())
1709 warpOp,
"Reduction distribution currently only supports floats and "
1712 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1718 if (reductionOp.getAcc()) {
1719 yieldValues.push_back(reductionOp.getAcc());
1720 retTypes.push_back(reductionOp.getAcc().getType());
1723 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1724 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1728 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1731 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1732 reductionOp.getKind(), newWarpOp.getWarpSize());
1733 if (reductionOp.getAcc()) {
1735 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1736 newWarpOp.getResult(newRetIndices[1]));
1743 DistributedReductionFn distributedReductionFn;
1754 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1757 patterns.add<WarpOpTransferWrite>(
patterns.getContext(), distributionMapFn,
1758 maxNumElementsToExtract, benefit);
1761 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1763 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
1766 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1767 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1768 WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1769 WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1771 patterns.add<WarpOpExtractScalar>(
patterns.getContext(), warpShuffleFromIdxFn,
1777 void mlir::vector::populateDistributeReduction(
1779 const DistributedReductionFn &distributedReductionFn,
1781 patterns.add<WarpOpReduction>(
patterns.getContext(), distributedReductionFn,
1788 return llvm::all_of(op->
getOperands(), definedOutside) &&
1792 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1793 Block *body = warpOp.getBody();
1796 llvm::SmallSetVector<Operation *, 8> opsToMove;
1799 auto isDefinedOutsideOfBody = [&](
Value value) {
1801 return (definingOp && opsToMove.count(definingOp)) ||
1802 warpOp.isDefinedOutsideOfRegion(value);
1808 bool hasVectorResult = llvm::any_of(op.
getResults(), [](
Value result) {
1809 return isa<VectorType>(result.getType());
1811 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
1812 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block, excluding the terminator operation at the end of the block.
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in the same or another block in the same function.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65535 (the most useful possible match).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOps supplying the operands.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
std::function< AffineMap(Value)> DistributionMapFn
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. N-1].
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N-1].
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ancestors of the limit.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.