#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  return AffineMap::get(sequentialType.getRank(), 0, perm,
                        distributedType.getContext());
}
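// Worked example (illustrative values, assuming a warp of 16 lanes):
// distributing vector<16x32xf32> as vector<16x2xf32> leaves dim 0 intact and
// splits dim 1, so the implicit map is (d0, d1) -> (d1).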
static int getDistributedDim(VectorType sequentialType,
                             VectorType distributedType) {
  assert(sequentialType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  int64_t distributedDim = -1;
  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
      assert(distributedDim == -1 && "found multiple distributed dims");
      distributedDim = i;
    }
  }
  return distributedDim;
}
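// For example (illustrative values): with sequential type vector<16x32xf32>
// and distributed type vector<16x2xf32>, only dim 1 differs, so this returns
// 1. It returns -1 when no dimension is distributed.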
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }
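  // The offset is laneId * distributedSize: each lane owns a contiguous chunk
  // of the sequential vector. E.g. (illustrative), with a distributed size of
  // 2 along `index`, lane 5 starts at offset 10.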
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return memref::StoreOp::create(b, loc, val, buffer, zero);

    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferWriteOp::create(
        b, loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return memref::LoadOp::create(b, loc, buffer, zero);

    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(cast<VectorType>(type).getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferReadOp::create(
        b, loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }
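  // buildStore and buildLoad are symmetric: a distributed value transits the
  // buffer at lane-dependent offsets along the distributed dims, while the
  // sequential value covers the whole buffer starting at offset zero.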
  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value isLane0 = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
                                  /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }
    if (!warpOp.getArgs().empty()) {
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    SmallVector<Value> replacements;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      helper.buildStore(rewriter, loc, sequentialVal, buffer);
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }
    if (!yieldOp.getOperands().empty()) {
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }
    scf::YieldOp::create(rewriter, yieldLoc);
    rewriter.replaceOp(warpOp, replacements);
    return success();
  }
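  // Overall effect (IR sketch, assuming a warp size of 32): a
  //   %r = gpu.warp_execute_on_lane_0(%laneid)[32] ... { ... }
  // becomes stores of the distributed inputs into temporary buffers, an
  // scf.if that runs the original body on lane 0 only, and distributed loads
  // of the results by all lanes, with warpSyncronizationFn calls between the
  // stores and the loads.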
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0)
        return VectorType();
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
  }
  if (warpSize != 1)
    return VectorType();
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}
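// Worked example (illustrative values): originalType = vector<16x64xf32>,
// map = (d0, d1) -> (d1), warpSize = 32 yields vector<16x2xf32>. A null
// VectorType is returned when the shape cannot absorb the full warp size.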
static std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
                  SmallVector<Type>>
getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
                             const DistributionMapFn &distributionMapFn) {
  llvm::SmallSetVector<Value, 32> escapingValues;
  SmallVector<Type> escapingValueTypes;
  SmallVector<Type> escapingValueDistTypes;
  if (innerRegion.empty())
    return {std::move(escapingValues), std::move(escapingValueTypes),
            std::move(escapingValueDistTypes)};
  mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
    Operation *parent = operand->get().getParentRegion()->getParentOp();
    if (warpOp->isAncestor(parent)) {
      if (!escapingValues.insert(operand->get()))
        return;
      Type distType = operand->get().getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(operand->get());
        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
      }
      escapingValueTypes.push_back(operand->get().getType());
      escapingValueDistTypes.push_back(distType);
    }
  });
  return {std::move(escapingValues), std::move(escapingValueTypes),
          std::move(escapingValueDistTypes)};
}
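// A value "escapes" into the inner region when it is defined inside the warp
// body but used inside a nested region (e.g. an scf.if branch). Such values
// must be yielded by the surrounding warp op in distributed form and then fed
// back into the inner warp op as block arguments.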
        maxNumElementsToExtract(maxNumElementsToExtract) {}
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();
    // 0-d writes are cloned as-is; there is nothing to distribute.
    if (writtenVectorType.getRank() == 0)
      return failure();

    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    VectorType maskType;
    if (writeOp.getMask()) {
      // Only minor-identity permutation maps are supported for masked writes.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
                         rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
                         delinearizedIdSizes)
                         .getResults();
    } else {
      // A single map result can use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);
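    // Net effect on each distributed dimension: the write index becomes
    // origIndex + distDimSize * laneIdComponent (the d0 + scale * d1 apply
    // above). E.g. (illustrative), with a per-lane tile of 2 elements, lane 5
    // writes at origIndex + 10.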
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only transfer writes.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
                                                     newWarpOp.getLaneid(),
                                                     newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
    return success();
  }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes are not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();
    return failure();
  }
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  unsigned maxNumElementsToExtract = 1;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
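    // Elementwise ops distribute transparently: the same op is replayed on
    // each lane's slice. E.g. (illustrative, warp size 32) an arith.addf on
    // vector<64xf32> inside the warp becomes an arith.addf on vector<2xf32>
    // outside of it.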
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
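    // Splat constants distribute for free: every lane's slice of a splat has
    // the same elements, so the distributed value is the same splat at the
    // distributed type, e.g. dense<1.0> : vector<64xf32> becomes
    // dense<1.0> : vector<2xf32> (illustrative, warp size 32).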
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
    if (!yieldOperand)
      return failure();
    auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
    auto resTy = cast<VectorType>(stepOp.getResult().getType());
    unsigned operandIdx = yieldOperand->getOperandNumber();
    if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv("Expected result size ({0}) to match warp size ({1})",
                        resTy.getNumElements(), warpOp.getWarpSize()));
    VectorType newVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
                                                  newVecTy, warpOp.getLaneid());
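    // With one element per lane, lane i's slice of vector.step [0..N-1] is
    // simply [i], so the distributed value is the lane id broadcast into a
    // vector<1xindex>.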
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      if (!read.getPermutationMap().isMinorIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding and mask follow the indices in the result list.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = vector::TransferReadOp::create(
        rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
        newIndices, read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());
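    // As with transfer_write distribution, each read index along a
    // distributed dimension becomes origIndex + distDimSize * laneIdComponent,
    // so every lane reads only its own tile.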
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    gpu::YieldOp yield = warpOp.getTerminator();

    // A value may be yielded multiple times; keep only the first yield
    // position for each unique value and drop results without uses.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated ones.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;
      // Values defined outside the body are uniform and can be forwarded.
      if (warpOp.isDefinedOutsideOfRegion(operand.get())) {
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      // Block arguments of the warp op forward the matching warp operand.
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = vector::BroadcastOp::create(
        rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
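    // The broadcast source is uniform across lanes, so it can be yielded
    // as-is and re-broadcast to the distributed type outside the warp op.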
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();
    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // The distributed type may have a smaller rank than the original type;
    // prepend size-one dimensions to make the ranks match.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = vector::ShapeCastOp::create(
        rewriter, oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();
    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
    // Early exit if any mask operand is defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();
    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());
    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }
    auto newMask =
        vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
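    // Per-lane mask bound: origBound - laneId_i * distShape_i, which
    // create_mask clamps to [0, distShape_i]. E.g. (illustrative, warp 32,
    // vector<64xi1> -> vector<2xi1>): a bound of 3 gives lane 0 two set
    // elements, lane 1 one set element, and all other lanes none.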
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");
    // Find the distributed dimension of the dest vector. There should be
    // exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");

    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // The distributed dim of the source must match the distributed dim of
    // the dest, so the distributed dim must lie within the last
    // srcType.getRank() dims of the dest.
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp,
          "distributed dimension must be in the last k dims of dest vector");
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");
    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    VectorType newDestTy = distributedType;
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Insert the distributed source into the distributed dest.
    Value newInsert = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), distributedDest.getType(),
        distributedSource, distributedDest, insertOp.getOffsets(),
        insertOp.getStrides());
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp =
        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          extractOp, "result vector type must be 2D or higher");

    // Find the distributed dimension. There should be exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");

    int64_t numOfExtractedDims =
        static_cast<int64_t>(extractOp.getSizes().size());
    // If the distributed dim is one of the extracted dims, it must be
    // extracted in full (offset 0, whole size).
    if (distributedDim < numOfExtractedDims) {
      int64_t distributedDimOffset =
          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
              .getInt();
      int64_t distributedDimSize =
          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
              .getInt();
      if (distributedDimOffset != 0 ||
          distributedDimSize != yieldedType.getDimSize(distributedDim))
        return rewriter.notifyMatchFailure(
            extractOp, "distributed dimension must be fully extracted");
    }
    SmallVector<int64_t> newDistributedShape(
        extractOp.getSourceVectorType().getShape());
    newDistributedShape[distributedDim] =
        distributedType.getDimSize(distributedDim);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
        newRetIndices);
    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    // Update the distributed sizes to match the distributed type.
    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));

    // Extract from the distributed vector.
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, distributedVec,
        extractOp.getOffsets(),
        ArrayAttr::get(rewriter.getContext(), distributedSizes),
        extractOp.getStrides());
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // 1-D and 0-D sources are handled by WarpOpExtractScalar.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution; simply move the extract out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getSource()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      Value newExtract = vector::ExtractOp::create(
          rewriter, loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield the source vector from the warp op with a distributed type.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
                                                 extractOp.getMixedPosition());
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only 0-D and 1-D sources are supported for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // Only f32 and i32 are currently supported by the shuffle lowering.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield the source vector and the dynamic position (if any).
    SmallVector<Value> additionalResults{extractOp.getSource()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0-d and single-element cases: every lane holds the value already.
    if (is0dOrVec1Extract) {
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      Value newExtract =
          vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    Value newPos =
        elementsPerLane == 1
            ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
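    // The 1-D case splits the extract position into the lane that owns the
    // element (the division by elementsPerLane) and the position within that
    // lane's slice (pos % elementsPerLane); the owning lane extracts and
    // warpShuffleFromIdxFn broadcasts the scalar to every lane.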
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Only 0-D and 1-D destinations are supported for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D source supported for now");
    }
    // Yield the dest vector, the scalar and the dynamic position (if any).
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getValueToStore()};
    SmallVector<Type> additionalResultTypes{
        distrType, insertOp.getValueToStore().getType()};
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-d vectors.
    if (vecType == distrType) {
      SmallVector<OpFoldResult> indices;
      if (pos)
        indices.push_back(pos);
      Value newInsert = vector::InsertOp::create(rewriter, loc, newSource,
                                                 distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // Distributed case: only one lane should do the insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
    Value isInsertingLane =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                              newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        scf::IfOp::create(
            rewriter, loc, isInsertingLane,
            /*thenBuilder=*/
            [&](OpBuilder &builder, Location loc) {
              Value newInsert = vector::InsertOp::create(
                  builder, loc, newSource, distributedVec, newPos);
              scf::YieldOp::create(builder, loc, newInsert);
            },
            /*elseBuilder=*/
            [&](OpBuilder &builder, Location loc) {
              scf::YieldOp::create(builder, loc, distributedVec);
            })
            .getResult(0);
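    // Only the lane whose slice contains the target position performs the
    // insert; every other lane yields its slice unchanged through the scf.if.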
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // 1-D and 0-D destinations are handled by WarpOpInsertScalar.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution; simply move the insert out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                                 distributedDest,
                                                 insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from the warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    Value newResult;
    if (distrSrcDim >= 0) {
      // The distributed dim is also inserted: every lane inserts its slice.
      newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                           distributedDest,
                                           insertOp.getMixedPosition());
    } else {
      // Only one lane owns the insertion point.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<int64_t> newPos =
          llvm::to_vector(insertOp.getStaticPosition());
      Value insertingLane = arith::ConstantIndexOp::create(
          rewriter, loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane =
          arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                newWarpOp.getLaneid(), insertingLane);
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
                                                   distributedDest, newPos);
        scf::YieldOp::create(builder, loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        scf::YieldOp::create(builder, loc, distributedDest);
      };
      newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
                                    /*thenBuilder=*/insertingBuilder,
                                    /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // Only pick up an scf.if that is the last op before the yield.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
    if (!ifOp)
      return failure();

    // Classify the warp op's yielded values: values that are not results of
    // the scf.if must be re-yielded by the new warp op; for scf.if results,
    // record the mapping between warp result index and if result index, and
    // the distributed type of vector results.
    SmallVector<Value> nonIfYieldValues;
    SmallVector<unsigned> nonIfYieldIndices;
    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
        nonIfYieldValues.push_back(yieldOperand.get());
        nonIfYieldIndices.push_back(yieldOperandIdx);
        continue;
      }
      OpResult ifResult = cast<OpResult>(yieldOperand.get());
      const unsigned ifResultIdx = ifResult.getResultNumber();
      ifResultMapping[yieldOperandIdx] = ifResultIdx;
      if (!isa<VectorType>(ifResult.getType()))
        continue;
      VectorType distType =
          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
      ifResultDistTypes[ifResultIdx] = distType;
    }

    // Collect the values escaping into each branch, with their original and
    // distributed types.
    auto [escapingValuesThen, escapingValueInputTypesThen,
          escapingValueDistTypesThen] =
        getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
                                     distributionMapFn);
    auto [escapingValuesElse, escapingValueInputTypesElse,
          escapingValueDistTypesElse] =
        getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
                                     distributionMapFn);
    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
        llvm::is_contained(escapingValueDistTypesElse, Type{}))
      return failure();

    // The new warp op yields the if condition, the escaping values of both
    // branches and all non-if yielded values.
    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
    newWarpOpYieldValues.append(escapingValuesThen.begin(),
                                escapingValuesThen.end());
    newWarpOpYieldValues.append(escapingValuesElse.begin(),
                                escapingValuesElse.end());
    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
                              escapingValueDistTypesThen.end());
    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                              escapingValueDistTypesElse.end());

    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
    for (auto [idx, val] :
         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
      newWarpOpYieldValues.push_back(val);
      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
    }

    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);

    // Compute the distributed result types of the new scf.if.
    SmallVector<Type> newIfOpDistResTypes;
    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
      Type distType = cast<Value>(res).getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(cast<Value>(res));
        distType = ifResultDistTypes.count(i)
                       ? ifResultDistTypes[i]
                       : getDistributedType(vecType, map, warpOp.getWarpSize());
      }
      newIfOpDistResTypes.push_back(distType);
    }

    // Create a new scf.if outside the new warp op, using the distributed
    // condition (result 0 of the new warp op).
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newIfOp = scf::IfOp::create(
        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
        static_cast<bool>(ifOp.thenBlock()),
        static_cast<bool>(ifOp.elseBlock()));
    auto encloseRegionInWarpOp =
        [&](Block *oldIfBranch, Block *newIfBranch,
            llvm::SmallSetVector<Value, 32> &escapingValues,
            SmallVector<Type> &escapingValueInputTypes,
            size_t warpResRangeStart) {
          OpBuilder::InsertionGuard g(rewriter);
          rewriter.setInsertionPointToStart(newIfBranch);
          SmallVector<Value> innerWarpInputVals;
          SmallVector<Type> innerWarpInputTypes;
          llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
          for (size_t i = 0; i < escapingValues.size();
               ++i, ++warpResRangeStart) {
            innerWarpInputVals.push_back(
                newWarpOp.getResult(warpResRangeStart));
            escapeValToBlockArgIndex[escapingValues[i]] =
                innerWarpInputTypes.size();
            innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
          }
          auto innerWarp = WarpExecuteOnLane0Op::create(
              rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
              newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
              innerWarpInputVals, innerWarpInputTypes);
          innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
          innerWarp.getWarpRegion().addArguments(
              innerWarpInputTypes,
              SmallVector<Location>(innerWarpInputTypes.size(),
                                    innerWarp.getLoc()));
          SmallVector<Value> yieldOperands;
          for (Value operand : oldIfBranch->getTerminator()->getOperands())
            yieldOperands.push_back(operand);
          rewriter.eraseOp(oldIfBranch->getTerminator());
          rewriter.setInsertionPointToEnd(innerWarp.getBody());
          gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
          rewriter.setInsertionPointAfter(innerWarp);
          scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
          // Rewrite uses of escaping values to the inner block arguments.
          innerWarp.walk([&](Operation *op) {
            for (OpOperand &operand : op->getOpOperands()) {
              auto it = escapeValToBlockArgIndex.find(operand.get());
              if (it == escapeValToBlockArgIndex.end())
                continue;
              operand.set(innerWarp.getBodyRegion().getArgument(it->second));
            }
          });
          mlir::vector::moveScalarUniformCode(innerWarp);
        };
    encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
                          &newIfOp.getThenRegion().front(), escapingValuesThen,
                          escapingValueInputTypesThen, 1);
    if (!ifOp.getElseRegion().empty())
      encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
                            &newIfOp.getElseRegion().front(),
                            escapingValuesElse, escapingValueInputTypesElse,
                            1 + escapingValuesThen.size());

    // Replace the old warp op results.
    for (auto [origIdx, newIdx] : ifResultMapping)
      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
                                    newIfOp.getResult(newIdx), newIfOp);
    for (auto [origIdx, newIdx] : origToNewYieldIdx)
      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
                                  newWarpOp.getResult(newIdx));
    return success();
  }
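  // In short: the scf.if is hoisted out of the warp op in distributed form;
  // each branch body is re-wrapped in its own nested warp op whose inputs are
  // the escaping values (in distributed form) yielded by the outer warp op,
  // and uniform scalar code is hoisted out of the nested warps.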
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // Only pick up an scf.for that is the last op before the yield.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect the values escaping into the loop body, with their original and
    // distributed types.
    auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
        getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
                                     distributionMapFn);
    if (llvm::is_contained(escapingValueDistTypes, Type{}))
      return failure();

    // Classify the warp op's yielded values: values that are not results of
    // the scf.for must be re-yielded; for scf.for results, record the
    // distributed type of each vector result.
    SmallVector<Value> nonForYieldedValues;
    SmallVector<unsigned> nonForResultIndices;
    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
        nonForYieldedValues.push_back(yieldOperand.get());
        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
        continue;
      }
      OpResult forResult = cast<OpResult>(yieldOperand.get());
      const unsigned forResultNumber = forResult.getResultNumber();
      forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
      if (!isa<VectorType>(forResult.getType()))
        continue;
      VectorType distType = cast<VectorType>(
          warpOp.getResult(yieldOperand.getOperandNumber()).getType());
      forResultDistTypes[forResultNumber] = distType;
    }

    // The new warp op yields the distributed loop init args, the escaping
    // values and the non-for yielded values.
    SmallVector<Value> newWarpOpYieldValues;
    SmallVector<Type> newWarpOpDistTypes;
    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
      newWarpOpYieldValues.push_back(initArg);
      Type distType = initArg.getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        // Reuse the distributed type of the matching yielded for result, if
        // any.
        AffineMap map = distributionMapFn(initArg);
        distType = forResultDistTypes.count(i)
                       ? forResultDistTypes[i]
                       : getDistributedType(vecType, map, warpOp.getWarpSize());
      }
      newWarpOpDistTypes.push_back(distType);
    }
    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
                                escapingValues.begin(), escapingValues.end());
    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
                              escapingValueDistTypes.begin(),
                              escapingValueDistTypes.end());

    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
    for (auto [i, v] :
         llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
      nonForResultMapping[i] = newWarpOpYieldValues.size();
      newWarpOpYieldValues.push_back(v);
      newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
    }

    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);

    // The distributed init args come first in the new warp op results.
    const unsigned escapingValuesStartIdx = forOp.getInitArgs().size();
    SmallVector<Value> newForOpOperands;
    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
      newForOpOperands.push_back(newWarpOp.getResult(i));

    // Create a new scf.for outside the new warp op with distributed iter
    // args.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newForOp = scf::ForOp::create(
        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
        forOp.getUnsignedCmp());

    // Wrap the loop body in an inner warp op that takes the distributed iter
    // args and the escaping values as inputs.
    rewriter.setInsertionPointToStart(newForOp.getBody());
    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
                                      newForOp.getRegionIterArgs().end());
    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
                                         forOp.getResultTypes().end());
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (size_t i = escapingValuesStartIdx;
         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
      innerWarpInput.push_back(newWarpOp.getResult(i));
      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
          innerWarpInputType.size();
      innerWarpInputType.push_back(
          escapingValueInputTypes[i - escapingValuesStartIdx]);
    }
    auto innerWarp = WarpExecuteOnLane0Op::create(
        rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
        newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
        innerWarpInputType);

    // Move the old loop body into the inner warp op.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments())
      argMapping.push_back(args);
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);

    // Replace the old warp op results.
    for (auto [origIdx, newIdx] : forResultMapping)
      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
                                    newForOp.getResult(newIdx), newForOp);
    for (auto [origIdx, newIdx] : nonForResultMapping)
      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
                                  newWarpOp.getResult(newIdx));

    // Rewrite uses of escaping values to the inner block arguments.
    innerWarp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    mlir::vector::moveScalarUniformCode(innerWarp);
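    // In short: loop-carried vectors travel through the new scf.for as
    // distributed iter_args, and the loop body runs inside a nested warp op
    // that receives those iter_args plus all escaping values.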
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank-1 reductions can be distributed.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must be a multiple of the warp "
                  "size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Obtain the data to reduce for a single lane.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    // Distribute and reduce across threads.
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }
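  // The reduction is distributed in two phases: each lane yields its local
  // slice (numElements = size / warpSize elements), distributedReductionFn
  // combines lane values across the warp (typically via shuffles), and the
  // accumulator, if present, is folded in at the end.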
  DistributedReductionFn distributedReductionFn;
void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}
void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns
      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
          patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                    benefit);
  patterns.add<WarpOpScfForOp, WarpOpScfIfOp>(patterns.getContext(),
                                              distributionMapFn, benefit);
}
void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
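// An op is considered uniform here when it produces no vector results and all
// of its operands come from outside the region (or from other hoisted ops);
// such ops compute the same value on every lane and can safely move out.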