22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallVectorExtras.h"
24 #include "llvm/Support/FormatVariadic.h"
43 VectorType distributedType) {
49 for (
unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
54 distributedType.getContext());
62 VectorType distributedType) {
63 assert(sequentialType.getRank() == distributedType.getRank() &&
64 "sequential and distributed vector types must have the same rank");
65 int64_t distributedDim = -1;
66 for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
70 assert(distributedDim == -1 &&
"found multiple distributed dims");
74 return distributedDim;
84 struct DistributedLoadStoreHelper {
85 DistributedLoadStoreHelper(
Value sequentialVal,
Value distributedVal,
87 : sequentialVal(sequentialVal), distributedVal(distributedVal),
88 laneId(laneId), zero(zero) {
89 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.
getType());
90 distributedVectorType = dyn_cast<VectorType>(distributedVal.
getType());
91 if (sequentialVectorType && distributedVectorType)
97 int64_t distributedSize = distributedVectorType.getDimSize(index);
99 return b.
createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
112 assert((val == distributedVal || val == sequentialVal) &&
113 "Must store either the preregistered distributed or the "
114 "preregistered sequential value.");
116 if (!isa<VectorType>(val.
getType()))
117 return b.
create<memref::StoreOp>(loc, val, buffer, zero);
121 int64_t rank = sequentialVectorType.getRank();
123 if (val == distributedVal) {
124 for (
auto dimExpr : distributionMap.getResults()) {
125 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126 indices[index] = buildDistributedOffset(b, loc, index);
130 return b.
create<vector::TransferWriteOp>(
131 loc, val, buffer, indices,
158 if (!isa<VectorType>(type))
159 return b.
create<memref::LoadOp>(loc, buffer, zero);
164 assert((type == distributedVectorType || type == sequentialVectorType) &&
165 "Must store either the preregistered distributed or the "
166 "preregistered sequential type.");
168 if (type == distributedVectorType) {
169 for (
auto dimExpr : distributionMap.getResults()) {
170 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171 indices[index] = buildDistributedOffset(b, loc, index);
175 return b.
create<vector::TransferReadOp>(
176 loc, cast<VectorType>(type), buffer, indices, std::nullopt,
180 Value sequentialVal, distributedVal, laneId, zero;
181 VectorType sequentialVectorType, distributedVectorType;
195 return rewriter.
create(res);
234 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236 assert(warpOp.getBodyRegion().hasOneBlock() &&
237 "expected WarpOp with single block");
238 Block *warpOpBody = &warpOp.getBodyRegion().
front();
246 Value c0 = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
248 loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
249 auto ifOp = rewriter.
create<scf::IfOp>(loc, isLane0,
251 rewriter.
eraseOp(ifOp.thenBlock()->getTerminator());
258 Value distributedVal = it.value();
259 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
260 warpOp.getLaneid(), c0);
264 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
267 helper.buildStore(rewriter, loc, distributedVal, buffer);
270 bbArgReplacements.push_back(
271 helper.buildLoad(rewriter, loc, sequentialVal.
getType(), buffer));
275 if (!warpOp.getArgs().empty()) {
277 options.warpSyncronizationFn(loc, rewriter, warpOp);
281 rewriter.
mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
288 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
289 Location yieldLoc = yieldOp.getLoc();
291 Value sequentialVal = it.value();
292 Value distributedVal = warpOp->getResult(it.index());
293 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
294 warpOp.getLaneid(), c0);
298 Value buffer =
options.warpAllocationFn(loc, rewriter, warpOp,
304 helper.buildStore(rewriter, loc, sequentialVal, buffer);
315 replacements.push_back(
316 helper.buildLoad(rewriter, loc, distributedVal.
getType(), buffer));
320 if (!yieldOp.getOperands().empty()) {
322 options.warpSyncronizationFn(loc, rewriter, warpOp);
328 rewriter.
create<scf::YieldOp>(yieldLoc);
331 rewriter.
replaceOp(warpOp, replacements);
348 static VectorType getDistributedType(VectorType originalType,
AffineMap map,
353 if (targetShape[position] % warpSize != 0) {
354 if (warpSize % targetShape[position] != 0) {
357 warpSize /= targetShape[position];
358 targetShape[position] = 1;
361 targetShape[position] = targetShape[position] / warpSize;
368 VectorType targetType =
396 maxNumElementsToExtract(maxNumElementsToExtract) {}
401 vector::TransferWriteOp writeOp,
402 WarpExecuteOnLane0Op warpOp)
const {
403 VectorType writtenVectorType = writeOp.getVectorType();
407 if (writtenVectorType.getRank() == 0)
411 AffineMap map = distributionMapFn(writeOp.getVector());
412 VectorType targetType =
413 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
419 if (writeOp.getMask()) {
426 if (!writeOp.getPermutationMap().isMinorIdentity())
429 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
434 vector::TransferWriteOp newWriteOp =
435 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
439 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
446 for (
auto [seqSize, distSize] :
447 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
448 assert(seqSize % distSize == 0 &&
"Invalid distributed vector shape");
449 delinearizedIdSizes.push_back(rewriter.
getIndexAttr(seqSize / distSize));
453 delinearized = rewriter
454 .
create<mlir::affine::AffineDelinearizeIndexOp>(
455 newWarpOp.getLoc(), newWarpOp.getLaneid(),
461 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
467 newWriteOp.getIndices().end());
470 bindDims(newWarpOp.getContext(), d0, d1);
471 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
474 unsigned indexPos = indexExpr.getPosition();
475 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
476 Value laneId = delinearized[vectorPos];
480 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
482 newWriteOp.getIndicesMutable().assign(indices);
489 vector::TransferWriteOp writeOp,
490 WarpExecuteOnLane0Op warpOp)
const {
492 VectorType vecType = writeOp.getVectorType();
494 if (vecType.getNumElements() > maxNumElementsToExtract) {
498 "writes more elements ({0}) than allowed to extract ({1})",
499 vecType.getNumElements(), maxNumElementsToExtract));
503 if (llvm::all_of(warpOp.getOps(),
504 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
510 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
511 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
515 auto secondWarpOp = rewriter.
create<WarpExecuteOnLane0Op>(
516 loc,
TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
517 Block &body = secondWarpOp.getBodyRegion().
front();
520 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
521 newWriteOp.getValueToStoreMutable().assign(
522 newWarpOp.getResult(newRetIndices[0]));
524 rewriter.
create<gpu::YieldOp>(newWarpOp.getLoc());
528 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
530 auto yield = cast<gpu::YieldOp>(
531 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
532 Operation *lastNode = yield->getPrevNode();
533 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
537 Value maybeMask = writeOp.getMask();
538 if (!llvm::all_of(writeOp->getOperands(), [&](
Value value) {
539 return writeOp.getVector() == value ||
540 (maybeMask && maybeMask == value) ||
541 warpOp.isDefinedOutsideOfRegion(value);
545 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
549 if (writeOp.getMask())
552 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
562 vector::TransferWriteOp cloneWriteOp(
RewriterBase &rewriter,
563 WarpExecuteOnLane0Op warpOp,
564 vector::TransferWriteOp writeOp,
565 VectorType targetType,
566 VectorType maybeMaskType)
const {
567 assert(writeOp->getParentOp() == warpOp &&
568 "write must be nested immediately under warp");
571 WarpExecuteOnLane0Op newWarpOp;
573 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
574 rewriter, warpOp,
ValueRange{writeOp.getVector(), writeOp.getMask()},
575 TypeRange{targetType, maybeMaskType}, newRetIndices);
577 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
578 rewriter, warpOp,
ValueRange{{writeOp.getVector()}},
583 cast<vector::TransferWriteOp>(rewriter.
clone(*writeOp.getOperation()));
585 newWriteOp.getValueToStoreMutable().assign(
586 newWarpOp.getResult(newRetIndices[0]));
588 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
593 unsigned maxNumElementsToExtract = 1;
616 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
626 Value distributedVal = warpOp.getResult(operandIndex);
632 if (
auto vecType = dyn_cast<VectorType>(distributedVal.
getType())) {
634 auto operandType = cast<VectorType>(operand.get().getType());
638 auto operandType = operand.get().getType();
639 assert(!isa<VectorType>(operandType) &&
640 "unexpected yield of vector from op with scalar result type");
641 targetType = operandType;
643 retTypes.push_back(targetType);
644 yieldValues.push_back(operand.get());
647 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
648 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
652 for (
unsigned i : llvm::seq(
unsigned(0), elementWise->
getNumOperands())) {
653 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
658 rewriter, loc, elementWise, newOperands,
659 {newWarpOp.getResult(operandIndex).
getType()});
682 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
685 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
689 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
698 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
701 Value distConstant = rewriter.
create<arith::ConstantOp>(loc, newAttr);
728 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
735 return isa<vector::TransferReadOp>(op) && op->
hasOneUse();
739 warpOp,
"warp result is not a vector.transfer_read op");
743 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
745 read,
"source must be defined outside of the region");
748 Value distributedVal = warpOp.getResult(operandIndex);
751 read.getIndices().end());
752 auto sequentialType = cast<VectorType>(read.getResult().getType());
753 auto distributedType = cast<VectorType>(distributedVal.
getType());
760 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
761 distributedType.getShape(), warpOp.getWarpSize(),
762 warpOp.getLaneid(), delinearizedIds)) {
764 read,
"cannot delinearize lane ID for distribution");
766 assert(!delinearizedIds.empty() || map.
getNumResults() == 0);
773 additionalResults.push_back(read.getPadding());
774 additionalResultTypes.push_back(read.getPadding().getType());
776 bool hasMask =
false;
777 if (read.getMask()) {
787 read,
"non-trivial permutation maps not supported");
788 VectorType maskType =
789 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
790 additionalResults.push_back(read.getMask());
791 additionalResultTypes.push_back(maskType);
795 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
796 rewriter, warpOp, additionalResults, additionalResultTypes,
798 distributedVal = newWarpOp.getResult(operandIndex);
802 for (int64_t i = 0, e = indices.size(); i < e; ++i)
803 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
808 bindDims(read.getContext(), d0, d1);
809 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
812 unsigned indexPos = indexExpr.getPosition();
813 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
814 int64_t scale = distributedType.getDimSize(vectorPos);
816 rewriter, read.getLoc(), d0 + scale * d1,
817 {newIndices[indexPos], delinearizedIds[vectorPos]});
821 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
824 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
826 auto newRead = rewriter.
create<vector::TransferReadOp>(
827 read.getLoc(), distributedVal.
getType(), read.getBase(), newIndices,
828 read.getPermutationMapAttr(), newPadding, newMask,
829 read.getInBoundsAttr());
840 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
843 newResultTypes.reserve(warpOp->getNumResults());
845 newYieldValues.reserve(warpOp->getNumResults());
848 auto yield = cast<gpu::YieldOp>(
849 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
859 for (
OpResult result : warpOp.getResults()) {
860 Value yieldOperand = yield.getOperand(result.getResultNumber());
861 auto it = dedupYieldOperandPositionMap.insert(
862 std::make_pair(yieldOperand, newResultTypes.size()));
863 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
864 if (result.use_empty() || !it.second)
866 newResultTypes.push_back(result.getType());
867 newYieldValues.push_back(yieldOperand);
870 if (yield.getNumOperands() == newYieldValues.size())
873 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
874 rewriter, warpOp, newYieldValues, newResultTypes);
877 newWarpOp.getBody()->walk([&](
Operation *op) {
884 newValues.reserve(warpOp->getNumResults());
885 for (
OpResult result : warpOp.getResults()) {
886 if (result.use_empty())
887 newValues.push_back(
Value());
890 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
901 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
903 auto yield = cast<gpu::YieldOp>(
904 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
906 unsigned resultIndex;
907 for (
OpOperand &operand : yield->getOpOperands()) {
916 valForwarded = operand.
get();
920 auto arg = dyn_cast<BlockArgument>(operand.
get());
921 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
923 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
926 valForwarded = warpOperand;
943 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
946 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
951 Location loc = broadcastOp.getLoc();
953 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
954 Value broadcastSrc = broadcastOp.getSource();
965 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
966 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
968 Value broadcasted = rewriter.
create<vector::BroadcastOp>(
969 loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
980 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
983 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
990 auto castDistributedType =
991 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
992 VectorType castOriginalType = oldCastOp.getSourceVectorType();
993 VectorType castResultType = castDistributedType;
997 unsigned castDistributedRank = castDistributedType.getRank();
998 unsigned castOriginalRank = castOriginalType.getRank();
999 if (castDistributedRank < castOriginalRank) {
1001 llvm::append_range(shape, castDistributedType.getShape());
1002 castDistributedType =
1007 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1008 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1011 Value newCast = rewriter.
create<vector::ShapeCastOp>(
1012 oldCastOp.getLoc(), castResultType,
1013 newWarpOp->getResult(newRetIndices[0]));
1039 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1042 getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1050 if (!llvm::all_of(mask->getOperands(), [&](
Value value) {
1051 return warpOp.isDefinedOutsideOfRegion(value);
1058 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1059 VectorType seqType = mask.getVectorType();
1067 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1068 warpOp.getWarpSize(), warpOp.getLaneid(),
1071 mask,
"cannot delinearize lane ID for distribution");
1072 assert(!delinearizedIds.empty());
1081 for (
int i = 0, e = distShape.size(); i < e; ++i) {
1088 rewriter, loc, s1 - s0 * distShape[i],
1089 {delinearizedIds[i], mask.getOperand(i)});
1090 newOperands.push_back(maskDimIdx);
1094 rewriter.
create<vector::CreateMaskOp>(loc, distType, newOperands);
1129 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1132 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1138 auto distributedType =
1139 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1142 if (distributedType.getRank() < 2)
1144 insertOp,
"result vector type must be 2D or higher");
1147 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1148 int64_t destDistributedDim =
1150 assert(destDistributedDim != -1 &&
"could not find distributed dimension");
1152 VectorType srcType = insertOp.getSourceVectorType();
1153 VectorType destType = insertOp.getDestVectorType();
1158 int64_t sourceDistributedDim =
1159 destDistributedDim - (destType.getRank() - srcType.getRank());
1160 if (sourceDistributedDim < 0)
1163 "distributed dimension must be in the last k dims of dest vector");
1165 if (srcType.getDimSize(sourceDistributedDim) !=
1166 destType.getDimSize(destDistributedDim))
1168 insertOp,
"distributed dimension must be fully inserted");
1170 insertOp.getSourceVectorType().getShape());
1171 newSourceDistShape[sourceDistributedDim] =
1172 distributedType.getDimSize(destDistributedDim);
1174 VectorType::get(newSourceDistShape, distributedType.getElementType());
1175 VectorType newDestTy = distributedType;
1177 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1178 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1179 {newSourceTy, newDestTy}, newRetIndices);
1181 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1182 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1185 Value newInsert = rewriter.
create<vector::InsertStridedSliceOp>(
1186 insertOp.getLoc(), distributedDest.
getType(), distributedSource,
1187 distributedDest, insertOp.getOffsets(), insertOp.getStrides());
1217 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1220 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1226 auto distributedType =
1227 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1230 if (distributedType.getRank() < 2)
1232 extractOp,
"result vector type must be 2D or higher");
1235 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1237 assert(distributedDim != -1 &&
"could not find distributed dimension");
1239 int64_t numOfExtractedDims =
1240 static_cast<int64_t
>(extractOp.getSizes().size());
1247 if (distributedDim < numOfExtractedDims) {
1248 int64_t distributedDimOffset =
1249 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1251 int64_t distributedDimSize =
1252 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1254 if (distributedDimOffset != 0 ||
1255 distributedDimSize != yieldedType.getDimSize(distributedDim))
1257 extractOp,
"distributed dimension must be fully extracted");
1260 extractOp.getSourceVectorType().getShape());
1261 newDistributedShape[distributedDim] =
1262 distributedType.getDimSize(distributedDim);
1263 auto newDistributedType =
1264 VectorType::get(newDistributedShape, distributedType.getElementType());
1266 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1267 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1271 extractOp.getSizes(), [](
Attribute attr) { return attr; });
1273 if (distributedDim <
static_cast<int64_t
>(distributedSizes.size()))
1275 distributedType.getDimSize(distributedDim));
1279 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1280 Value newExtract = rewriter.
create<vector::ExtractStridedSliceOp>(
1281 extractOp.getLoc(), distributedType, distributedVec,
1282 extractOp.getOffsets(),
1284 extractOp.getStrides());
1295 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1298 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1303 VectorType extractSrcType = extractOp.getSourceVectorType();
1307 if (extractSrcType.getRank() <= 1) {
1313 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1320 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1321 rewriter, warpOp, {extractOp.getVector()},
1322 {extractOp.getSourceVectorType()}, newRetIndices);
1324 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1326 Value newExtract = rewriter.
create<vector::ExtractOp>(
1327 loc, distributedVec, extractOp.getMixedPosition());
1334 auto distributedType =
1335 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1336 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1338 assert(distributedDim != -1 &&
"could not find distributed dimension");
1339 (void)distributedDim;
1343 for (
int i = 0; i < distributedType.getRank(); ++i)
1344 newDistributedShape[i + extractOp.getNumIndices()] =
1345 distributedType.getDimSize(i);
1346 auto newDistributedType =
1347 VectorType::get(newDistributedShape, distributedType.getElementType());
1349 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1350 rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1353 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1355 Value newExtract = rewriter.
create<vector::ExtractOp>(
1356 loc, distributedVec, extractOp.getMixedPosition());
1366 WarpOpExtractScalar(
MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1369 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1372 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1377 VectorType extractSrcType = extractOp.getSourceVectorType();
1379 if (extractSrcType.getRank() > 1) {
1381 extractOp,
"only 0-D or 1-D source supported for now");
1385 if (!extractSrcType.getElementType().isF32() &&
1386 !extractSrcType.getElementType().isInteger(32))
1388 extractOp,
"only f32/i32 element types are supported");
1389 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1390 Type elType = extractSrcType.getElementType();
1391 VectorType distributedVecType;
1392 if (!is0dOrVec1Extract) {
1393 assert(extractSrcType.getRank() == 1 &&
1394 "expected that extract src rank is 0 or 1");
1395 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1397 int64_t elementsPerLane =
1398 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1401 distributedVecType = extractSrcType;
1406 additionalResults.append(
1408 additionalResultTypes.append(
1413 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1414 rewriter, warpOp, additionalResults, additionalResultTypes,
1417 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1421 if (is0dOrVec1Extract) {
1425 rewriter.
create<vector::ExtractOp>(loc, distributedVec, indices);
1431 int64_t staticPos = extractOp.getStaticPosition()[0];
1433 ? (newWarpOp->getResult(newRetIndices[1]))
1437 int64_t elementsPerLane = distributedVecType.getShape()[0];
1441 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1444 elementsPerLane == 1
1445 ? rewriter.
create<arith::ConstantIndexOp>(loc, 0).getResult()
1447 sym0 % elementsPerLane, pos);
1449 rewriter.
create<vector::ExtractOp>(loc, distributedVec, newPos);
1452 Value shuffled = warpShuffleFromIdxFn(
1453 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1459 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1466 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1468 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1473 VectorType vecType = insertOp.getDestVectorType();
1474 VectorType distrType =
1475 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1478 if (vecType.getRank() > 1) {
1480 insertOp,
"only 0-D or 1-D source supported for now");
1485 insertOp.getValueToStore()};
1487 distrType, insertOp.getValueToStore().getType()};
1489 additionalResultTypes.append(
1494 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1495 rewriter, warpOp, additionalResults, additionalResultTypes,
1498 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1499 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1503 if (vecType.getRank() != 0) {
1504 int64_t staticPos = insertOp.getStaticPosition()[0];
1505 pos = ShapedType::isDynamic(staticPos)
1506 ? (newWarpOp->getResult(newRetIndices[2]))
1511 if (vecType == distrType) {
1515 indices.push_back(pos);
1517 newInsert = rewriter.
create<vector::InsertOp>(loc, newSource,
1518 distributedVec, indices);
1526 int64_t elementsPerLane = distrType.getShape()[0];
1530 rewriter, loc, sym0.
ceilDiv(elementsPerLane), pos);
1533 rewriter, loc, sym0 % elementsPerLane, pos);
1534 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1535 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1539 loc, isInsertingLane,
1542 Value newInsert = builder.create<vector::InsertOp>(
1543 loc, newSource, distributedVec, newPos);
1544 builder.create<scf::YieldOp>(loc, newInsert);
1548 builder.create<scf::YieldOp>(loc, distributedVec);
1558 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1560 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1568 if (insertOp.getDestVectorType().getRank() <= 1) {
1574 if (warpOp.getResult(operandNumber).getType() == operand->
get().
getType()) {
1578 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1579 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1580 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1583 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1584 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1585 Value newResult = rewriter.
create<vector::InsertOp>(
1586 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1593 auto distrDestType =
1594 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1595 auto yieldedType = cast<VectorType>(operand->
get().
getType());
1596 int64_t distrDestDim = -1;
1597 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1598 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1601 assert(distrDestDim == -1 &&
"found multiple distributed dims");
1605 assert(distrDestDim != -1 &&
"could not find distributed dimension");
1608 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1616 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1617 if (distrSrcDim >= 0)
1618 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1624 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1625 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1626 {distrSrcType, distrDestType}, newRetIndices);
1628 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1629 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1633 if (distrSrcDim >= 0) {
1635 newResult = rewriter.
create<vector::InsertOp>(
1636 loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1639 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1643 Value insertingLane = rewriter.
create<arith::ConstantIndexOp>(
1644 loc, newPos[distrDestDim] / elementsPerLane);
1645 Value isInsertingLane = rewriter.
create<arith::CmpIOp>(
1646 loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1648 newPos[distrDestDim] %= elementsPerLane;
1650 Value newInsert = builder.
create<vector::InsertOp>(
1651 loc, distributedSrc, distributedDest, newPos);
1652 builder.
create<scf::YieldOp>(loc, newInsert);
1655 builder.
create<scf::YieldOp>(loc, distributedDest);
1657 newResult = rewriter
1658 .
create<scf::IfOp>(loc, isInsertingLane,
1660 nonInsertingBuilder)
1705 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1707 auto warpOpYield = cast<gpu::YieldOp>(
1708 warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1710 Operation *lastNode = warpOpYield->getPrevNode();
1711 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1716 llvm::SmallSetVector<Value, 32> escapingValues;
1720 forOp.getBodyRegion(), [&](
OpOperand *operand) {
1721 Operation *parent = operand->get().getParentRegion()->getParentOp();
1722 if (warpOp->isAncestor(parent)) {
1723 if (!escapingValues.insert(operand->get()))
1725 Type distType = operand->get().getType();
1726 if (auto vecType = dyn_cast<VectorType>(distType)) {
1727 AffineMap map = distributionMapFn(operand->get());
1728 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1730 escapingValueInputTypes.push_back(operand->get().getType());
1731 escapingValueDistTypes.push_back(distType);
1735 if (llvm::is_contained(escapingValueDistTypes,
Type{}))
1748 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
1749 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
1750 for (
OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1753 nonForYieldedValues.push_back(yieldOperand.
get());
1757 OpResult forResult = cast<OpResult>(yieldOperand.
get());
1762 if (!isa<VectorType>(forResult.
getType()))
1764 VectorType distType = cast<VectorType>(
1766 forResultDistTypes[forResultNumber] = distType;
1776 newWarpOpYieldValues.push_back(initArg);
1778 Type distType = initArg.getType();
1779 if (
auto vecType = dyn_cast<VectorType>(distType)) {
1783 AffineMap map = distributionMapFn(initArg);
1784 distType = forResultDistTypes.count(i)
1785 ? forResultDistTypes[i]
1786 : getDistributedType(vecType, map, warpOp.getWarpSize());
1788 newWarpOpDistTypes.push_back(distType);
1791 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
1792 escapingValues.begin(), escapingValues.end());
1793 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
1794 escapingValueDistTypes.begin(),
1795 escapingValueDistTypes.end());
1800 llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
1802 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
1803 nonForResultMapping[i] = newWarpOpYieldValues.size();
1804 newWarpOpYieldValues.push_back(v);
1805 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
1808 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1809 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1813 const unsigned escapingValuesStartIdx =
1814 forOp.getInitArgs().size();
1817 for (
size_t i = 0; i < escapingValuesStartIdx; ++i)
1818 newForOpOperands.push_back(newWarpOp.getResult(i));
1823 auto newForOp = rewriter.
create<scf::ForOp>(
1824 forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1825 forOp.getStep(), newForOpOperands);
1832 newForOp.getRegionIterArgs().end());
1834 forOp.getResultTypes().end());
1838 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1839 for (
size_t i = escapingValuesStartIdx;
1840 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
1841 innerWarpInput.push_back(newWarpOp.getResult(i));
1842 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1843 innerWarpInputType.size();
1844 innerWarpInputType.push_back(
1845 escapingValueInputTypes[i - escapingValuesStartIdx]);
1848 auto innerWarp = rewriter.
create<WarpExecuteOnLane0Op>(
1849 newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1850 newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
1854 argMapping.push_back(newForOp.getInductionVar());
1855 for (
Value args : innerWarp.getBody()->getArguments())
1856 argMapping.push_back(args);
1858 argMapping.resize(forOp.getBody()->getNumArguments());
1860 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
1861 yieldOperands.push_back(operand);
1863 rewriter.
eraseOp(forOp.getBody()->getTerminator());
1864 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1869 rewriter.
create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
1873 if (!innerWarp.getResults().empty())
1874 rewriter.
create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1878 for (
auto [origIdx, newIdx] : forResultMapping)
1880 newForOp.getResult(newIdx), newForOp);
1883 for (
auto [origIdx, newIdx] : nonForResultMapping)
1885 newWarpOp.getResult(newIdx));
1894 auto it = argIndexMapping.find(operand.
get());
1895 if (it == argIndexMapping.end())
1897 operand.
set(innerWarp.getBodyRegion().getArgument(it->second));
1902 mlir::vector::moveScalarUniformCode(innerWarp);
1931 DistributedReductionFn distributedReductionFn,
1934 distributedReductionFn(std::move(distributedReductionFn)) {}
1936 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1939 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1945 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1947 if (vectorType.getRank() != 1)
1949 warpOp,
"Only rank 1 reductions can be distributed.");
1951 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1953 warpOp,
"Reduction vector dimension must match was size.");
1954 if (!reductionOp.getType().isIntOrFloat())
1956 warpOp,
"Reduction distribution currently only supports floats and "
1959 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1965 if (reductionOp.getAcc()) {
1966 yieldValues.push_back(reductionOp.getAcc());
1967 retTypes.push_back(reductionOp.getAcc().getType());
1970 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1971 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1975 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1978 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1979 reductionOp.getKind(), newWarpOp.getWarpSize());
1980 if (reductionOp.getAcc()) {
1982 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1983 newWarpOp.getResult(newRetIndices[1]));
1990 DistributedReductionFn distributedReductionFn;
2001 void mlir::vector::populateDistributeTransferWriteOpPatterns(
2004 patterns.add<WarpOpTransferWrite>(
patterns.getContext(), distributionMapFn,
2005 maxNumElementsToExtract, benefit);
2008 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2010 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
PatternBenefit benefit,
2014 .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2015 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
2016 WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
2017 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
2019 patterns.add<WarpOpExtractScalar>(
patterns.getContext(), warpShuffleFromIdxFn,
2025 void mlir::vector::populateDistributeReduction(
2027 const DistributedReductionFn &distributedReductionFn,
2029 patterns.add<WarpOpReduction>(
patterns.getContext(), distributedReductionFn,
2036 return llvm::all_of(op->
getOperands(), definedOutside) &&
2040 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2041 Block *body = warpOp.getBody();
2044 llvm::SmallSetVector<Operation *, 8> opsToMove;
2047 auto isDefinedOutsideOfBody = [&](
Value value) {
2049 return (definingOp && opsToMove.count(definingOp)) ||
2050 warpOp.isDefinedOutsideOfRegion(value);
2056 bool hasVectorResult = llvm::any_of(op.
getResults(), [](
Value result) {
2057 return isa<VectorType>(result.getType());
2059 if (!hasVectorResult &&
canBeHoisted(&op, isDefinedOutsideOfBody))
2060 opsToMove.insert(&op);
static llvm::ManagedStatic< PassManagerOptions > options
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
std::function< AffineMap(Value)> DistributionMapFn
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.