#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
// ...

/// Currently the distribution map is implicit based on the vector shape: the
/// dimensions whose sizes differ between the sequential and the distributed
/// type are the distributed dimensions.
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  return AffineMap::get(sequentialType.getRank(), 0, perm,
                        distributedType.getContext());
}
/// Given a sequential and distributed vector type, returns the distributed
/// dimension.
static int getDistributedDim(VectorType sequentialType,
                             VectorType distributedType) {
  assert(sequentialType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  int64_t distributedDim = -1;
  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
      // Only a single dimension may be distributed.
      assert(distributedDim == -1 && "found multiple distributed dims");
      distributedDim = i;
    }
  }
  return distributedDim;
}
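
// Illustrative example (not in the original listing): with a 32-lane warp, a
// sequential type vector<16x32xf32> distributed as vector<16x1xf32> differs
// only in dim 1, so getDistributedDim returns 1 and calculateImplicitMap
// returns the map (d0, d1) -> (d1).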
// ...

struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  /// Compute this lane's offset into distributed dimension `index`:
  /// laneId * distributedSize.
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }
  /// Store `val` (either the sequential or the distributed value) into
  /// `buffer`.
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return memref::StoreOp::create(b, loc, val, buffer, zero);

    // Vector case must use a vector.transfer_write, offset by the lane id
    // along the distributed dimensions when storing the distributed value.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    return vector::TransferWriteOp::create(b, loc, val, buffer, indices,
                                           /*...*/);
  }
  /// Load a value of `type` (either the sequential or the distributed type)
  /// from `buffer`.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return memref::LoadOp::create(b, loc, buffer, zero);

    // Vector case must use a vector.transfer_read.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(cast<VectorType>(type).getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    return vector::TransferReadOp::create(
        b, loc, cast<VectorType>(type), buffer, indices, /*...*/);
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp->getLoc();

    // Step 1: Guard the body with an scf.if so that only lane 0 executes it.
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value isLane0 = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
                                  /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: Insert memory transfers for the warp op's block arguments:
    // store the distributed value before the scf.if, reload the sequential
    // value inside it.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);
      // ...
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // ...
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Synchronize after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      // ...
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 3: Move the body of warpOp into the scf.if.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 4: Insert memory transfers for the yielded values: store the
    // sequential value inside the scf.if, reload the distributed value after
    // it.
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    SmallVector<Value> replacements;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);
      // ...
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      helper.buildStore(rewriter, loc, sequentialVal, buffer);
      // ...
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Synchronize after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      // ...
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Replace the terminator with an empty scf.yield and the warp op results
    // with the reloaded values.
    // ...
    scf::YieldOp::create(rewriter, yieldLoc);
    // ...
    rewriter.replaceOp(warpOp, replacements);
    return success();
  }

  const WarpExecuteOnLane0LoweringOptions &options;
};
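
// Illustrative effect of the pattern above (schematic IR; buffer allocation
// via options.warpAllocationFn and synchronization via
// options.warpSyncronizationFn are target-specific):
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %1 = "some_def"() : () -> (vector<32xf32>)
//     gpu.yield %1 : vector<32xf32>
//   }
//
// becomes, roughly:
//
//   %buffer = "allocate"() : () -> memref<32xf32>
//   %cnd = arith.cmpi eq, %laneid, %c0 : index
//   scf.if %cnd {
//     %1 = "some_def"() : () -> (vector<32xf32>)
//     vector.transfer_write %1, %buffer[%c0] : vector<32xf32>, memref<32xf32>
//   }
//   "sync"() : () -> ()
//   %0 = vector.transfer_read %buffer[%laneid], %cst
//       : memref<32xf32>, vector<1xf32>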
/// Return the distributed vector type obtained by dividing `originalType`
/// across `warpSize` lanes along the dimensions selected by `map`, or a null
/// VectorType if no such distribution exists.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0)
        return VectorType();
      // Dimension smaller than the remaining warp factor: fold it to 1 and
      // carry the residual factor to the next distributed dimension.
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    // Divide the dimension evenly across the remaining warp factor.
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1)
    return VectorType();
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}
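
// Worked example: originalType = vector<64xf32>, map = (d0) -> (d0) and
// warpSize = 32 gives targetShape [2], i.e. vector<2xf32> per lane. With
// vector<16xf32> instead, dim 0 folds to 1 but a residual factor of 2
// remains, so no distributed type exists and a null VectorType is returned.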
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute a transfer_write that is the last op before the terminator.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 0-D writes are handled by the extract-based path instead.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // Compute the distributed mask type, if any. Only minor-identity
    // permutation maps are supported for masked writes.
    VectorType maskType;
    if (writeOp.getMask()) {
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest of the body.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id over the per-dimension distribution factors.
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
                         rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
                         delinearizedIdSizes)
                         .getResults();
    } else {
      // With a single distributed dimension, use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    // Rescale the indices of the distributed dimensions:
    // newIndex = oldIndex + laneId * distributedDimSize.
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);
    return success();
  }
  /// Fallback: keep the write on lane 0 only, for small vectors.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();
    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only a transfer_write and the
    // terminator.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only the write.
    auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
                                                     newWarpOp.getLaneid(),
                                                     newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
    return success();
  }
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    // The write may only depend on the written vector, the mask, and values
    // defined outside of the region.
    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }
private:
  /// Clone `writeOp`, assumed to be nested under `warpOp`, into a new warp
  /// execute op with the proper return types, and rewire it to write the
  /// yielded distributed vector (and mask, if present).
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands are distributed to the
        // same shape, each keeping its own element type.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    // Rebuild the elementwise op on the distributed operands, outside of the
    // warp op.
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};
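
// Illustrative effect (schematic IR, assuming a 32-lane warp):
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %s = arith.addf %a, %b : vector<32xf32>
//     gpu.yield %s : vector<32xf32>
//   }
//
// becomes, roughly:
//
//   %w:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
//       -> (vector<1xf32>, vector<1xf32>) {
//     gpu.yield %a, %b : vector<32xf32>, vector<32xf32>
//   }
//   %0 = arith.addf %w#0, %w#1 : vector<1xf32>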
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    // Only splat constants distribute trivially: each lane gets a smaller
    // splat of the same scalar.
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    return success();
  }
};
struct WarpOpStep : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
    if (!yieldOperand)
      return failure();
    unsigned operandIdx = yieldOperand->getOperandNumber();
    auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
    VectorType resTy = stepOp.getType();
    if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
      return rewriter.notifyMatchFailure(
          stepOp,
          llvm::formatv("Expected result size ({0}) to equal warp size ({1})",
                        resTy.getNumElements(), warpOp.getWarpSize()));
    VectorType newVecTy =
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
    rewriter.setInsertionPointAfter(warpOp);
    // Each lane of the distributed step holds exactly its own lane id.
    Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
                                                  newVecTy, warpOp.getLaneid());
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
    return success();
  }
};
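
// Example: a vector.step producing vector<32xindex> over a 32-lane warp
// distributes to vector<1xindex> per lane; the single element lane k would
// hold is k itself, hence the broadcast of %laneid above.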
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value> indices(read.getIndices().begin(),
                               read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Yield the indices, the padding, and the optional mask from the warp op.
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // Only minor-identity permutation maps are supported for masked reads.
      if (!read.getPermutationMap().isMinorIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first, then padding, then the mask.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    // Rescale the indices of the distributed dimensions:
    // newIndex = oldIndex + laneId * distributedDimSize.
    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = vector::TransferReadOp::create(
        rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
        newIndices, read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());
    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};
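
// Illustrative effect (schematic IR, assuming a 32-lane warp):
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %v = vector.transfer_read %src[%c0], %cst
//         : memref<32xf32>, vector<32xf32>
//     gpu.yield %v : vector<32xf32>
//   }
//
// becomes, roughly (the index is rescaled by the distributed dim size 1):
//
//   %off = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%c0, %laneid)
//   %0 = vector.transfer_read %src[%off], %cst
//       : memref<32xf32>, vector<1xf32>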
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, unsigned> dedupYieldOperandPositionMap;
    DenseMap<OpResult, unsigned> dedupResultPositionMap;
    gpu::YieldOp yield = warpOp.getTerminator();

    // A value may be yielded multiple times; record the first position at
    // which each value is yielded, map every result to that deduplicated
    // position, and drop results that are dead or duplicated.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp yield = warpOp.getTerminator();
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      // Or forward a warp op argument through a block argument of matching
      // type.
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    return success();
  }
};
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();
    // ...
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = vector::BroadcastOp::create(
        rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};
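
// Illustrative effect (schematic IR): the broadcast source is yielded from
// the warp op and re-broadcast to the distributed type outside, e.g.
//
//   %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
//     %b = vector.broadcast %src : f32 to vector<32xf32>
//     gpu.yield %b : vector<32xf32>
//   }
//
// becomes, roughly:
//
//   %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
//     gpu.yield %src : f32
//   }
//   %0 = vector.broadcast %w : f32 to vector<1xf32>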
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // The distributed type may have a smaller rank than the original type.
    // Prepend size-one dimensions to make the ranks match.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = vector::ShapeCastOp::create(
        rewriter, oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();
    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Each lane computes its own bound per dimension: the global bound
    // shifted down by the lane's element offset along that dimension.
    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    return success();
  }
};
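
// Worked example for the per-lane bound s1 - s0 * distShape[i]: with
// distShape = [2], lane s0 = 2 covers elements [4, 6) of the sequential
// vector. For a global bound s1 = 5 the lane computes 5 - 4 = 1, so exactly
// one of its two elements is masked on (create_mask clamps the operand into
// [0, distShape[i]] on its own).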
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Only 2-D or higher distributed result types are handled.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");

    // Find the distributed dimension of the dest vector. There should be
    // exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");

    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // The distributed dimension of the dest must map into the source vector,
    // i.e. come from the trailing dimensions of the dest.
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp,
          "distributed dimension must be in the last k dims of dest vector");

    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");

    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    VectorType newDestTy = distributedType;
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Insert the distributed source into the distributed dest, keeping the
    // original offsets and strides.
    Value newInsert = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), distributedDest.getType(),
        distributedSource, distributedDest, insertOp.getOffsets(),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), newInsert);
    return success();
  }
};
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp =
        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Only 2-D or higher distributed result types are handled.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          extractOp, "result vector type must be 2D or higher");

    // Find the distributed dimension. There should be exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");

    int64_t numOfExtractedDims =
        static_cast<int64_t>(extractOp.getSizes().size());
    // If the distributed dim is one of the extracted dims, it must be
    // extracted in full, starting from offset 0.
    if (distributedDim < numOfExtractedDims) {
      int64_t distributedDimOffset =
          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
              .getInt();
      int64_t distributedDimSize =
          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
              .getInt();
      if (distributedDimOffset != 0 ||
          distributedDimSize != yieldedType.getDimSize(distributedDim))
        return rewriter.notifyMatchFailure(
            extractOp, "distributed dimension must be fully extracted");
    }
    SmallVector<int64_t> newDistributedShape(
        extractOp.getSourceVectorType().getShape());
    newDistributedShape[distributedDim] =
        distributedType.getDimSize(distributedDim);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);

    // Rewrite the extracted size along the distributed dim to the per-lane
    // size.
    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));

    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, distributedVec,
        extractOp.getOffsets(),
        ArrayAttr::get(rewriter.getContext(), distributedSizes),
        extractOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newExtract);
    return success();
  }
};
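
// Example (illustrative): with a 32-lane warp and distributed dim 0,
// extracting sizes [64, 8] at offsets [0, 8] from vector<64x16xf32> yields
// the distributed result vector<2x8xf32>; the source is re-typed to
// vector<2x16xf32> and the extracted size along dim 0 is rewritten from 64
// to the per-lane size 2.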
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // 0-D and 1-D sources are handled by the scalar-extract pattern instead.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      Value newExtract = vector::ExtractOp::create(
          rewriter, loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield the source vector from the warp op with the distributed trailing
    // dimensions.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from the distributed vector.
    Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
                                                 extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                newExtract);
    return success();
  }
};
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Higher-rank sources are peeled by WarpOpExtract first.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }

    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield the source vector and the dynamic position (if any) from the
    // warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0-D and vector<1x..> cases: the value is the same on every lane, so
    // each lane extracts it directly.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));

    // Compute the owning lane and the position within that lane, extract
    // there, then shuffle the value to all lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    Value newPos =
        elementsPerLane == 1
            ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos)
                  .getResult();
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);

    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
struct WarpOpInsertScalar : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());

    // Higher-rank destinations are handled by WarpOpInsert instead.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D source supported for now");
    }

    // Yield the destination vector, the source scalar, and the dynamic
    // position (if any) from the warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getValueToStore()};
    SmallVector<Type> additionalResultTypes{
        distrType, insertOp.getValueToStore().getType()};
    additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));

    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);

    OpFoldResult pos;
    if (vecType.getRank() != 0) {
      int64_t staticPos = insertOp.getStaticPosition()[0];
      pos = ShapedType::isDynamic(staticPos)
                ? (newWarpOp->getResult(newRetIndices[2]))
                : OpFoldResult(rewriter.getIndexAttr(staticPos));
    }

    // This condition is always true for 0-d vectors.
    if (vecType == distrType) {
      // The destination is not distributed: every lane performs the insert.
      Value newInsert;
      SmallVector<OpFoldResult> indices;
      if (pos) {
        indices.push_back(pos);
      }
      newInsert = vector::InsertOp::create(rewriter, loc, newSource,
                                           distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // The destination is distributed: only the lane that owns the position
    // performs the insert; the other lanes forward their vector unchanged.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    Value insertingLane = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    Value newPos = affine::makeComposedAffineApply(
        rewriter, loc, sym0 % elementsPerLane, pos);
    Value isInsertingLane =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                              newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        scf::IfOp::create(
            rewriter, loc, isInsertingLane,
            /*thenBuilder=*/
            [&](OpBuilder &builder, Location loc) {
              Value newInsert = vector::InsertOp::create(
                  builder, loc, newSource, distributedVec, newPos);
              scf::YieldOp::create(builder, loc, newInsert);
            },
            /*elseBuilder=*/
            [&](OpBuilder &builder, Location loc) {
              scf::YieldOp::create(builder, loc, distributedVec);
            })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
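
// Worked example for the lane selection above: inserting a scalar at
// position 7 of vector<32xf32> distributed over 32 lanes (one element per
// lane) gives insertingLane = 7 / 1 = 7 and newPos = 7 % 1 = 0, so only
// lane 7 performs the vector.insert at its local position 0; every other
// lane yields its distributed vector unchanged.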
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // 0-D and 1-D destinations are handled by WarpOpInsertScalar.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                                 distributedDest,
                                                 insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Only a single dimension may be distributed.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type. If the distributed dim
    // maps into the source vector, distribute it the same way; otherwise the
    // whole source is inserted by a single lane.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield the source and dest vectors from the warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts its slice of the distributed source vector.
      newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                           distributedDest,
                                           insertOp.getMixedPosition());
    } else {
      // One lane inserts the whole source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<int64_t> newPos =
          llvm::to_vector(insertOp.getStaticPosition());
      // Lane that owns the position: pos / elementsPerLane.
      Value insertingLane = arith::ConstantIndexOp::create(
          rewriter, loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane =
          arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                newWarpOp.getLaneid(), insertingLane);
      // Insert position within that lane: pos % elementsPerLane.
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
                                                   distributedDest, newPos);
        scf::YieldOp::create(builder, loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        scf::YieldOp::create(builder, loc, distributedDest);
      };
      newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
                                    /*thenBuilder=*/insertingBuilder,
                                    /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
1754 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1756 Operation *lastNode = warpOpYield->getPrevNode();
1757 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1762 llvm::SmallSetVector<Value, 32> escapingValues;
1766 forOp.getBodyRegion(), [&](
OpOperand *operand) {
1767 Operation *parent = operand->get().getParentRegion()->getParentOp();
1768 if (warpOp->isAncestor(parent)) {
1769 if (!escapingValues.insert(operand->get()))
1771 Type distType = operand->get().getType();
1772 if (auto vecType = dyn_cast<VectorType>(distType)) {
1773 AffineMap map = distributionMapFn(operand->get());
1774 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1776 escapingValueInputTypes.push_back(operand->get().getType());
1777 escapingValueDistTypes.push_back(distType);
1781 if (llvm::is_contained(escapingValueDistTypes,
Type{}))
1794 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
1795 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
1796 for (
OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1799 nonForYieldedValues.push_back(yieldOperand.
get());
1803 OpResult forResult = cast<OpResult>(yieldOperand.
get());
1808 if (!isa<VectorType>(forResult.
getType()))
1810 VectorType distType = cast<VectorType>(
1812 forResultDistTypes[forResultNumber] = distType;
1822 newWarpOpYieldValues.push_back(initArg);
1824 Type distType = initArg.getType();
1825 if (
auto vecType = dyn_cast<VectorType>(distType)) {
1829 AffineMap map = distributionMapFn(initArg);
1830 distType = forResultDistTypes.count(i)
1831 ? forResultDistTypes[i]
1832 : getDistributedType(vecType, map, warpOp.getWarpSize());
1834 newWarpOpDistTypes.push_back(distType);
1837 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
1838 escapingValues.begin(), escapingValues.end());
1839 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
1840 escapingValueDistTypes.begin(),
1841 escapingValueDistTypes.end());
1846 llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
1848 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
1849 nonForResultMapping[i] = newWarpOpYieldValues.size();
1850 newWarpOpYieldValues.push_back(v);
1851 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
1854 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1855 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1859 const unsigned escapingValuesStartIdx =
1860 forOp.getInitArgs().size();
1863 for (
size_t i = 0; i < escapingValuesStartIdx; ++i)
1864 newForOpOperands.push_back(newWarpOp.getResult(i));
1869 auto newForOp = scf::ForOp::create(
1870 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1871 forOp.getStep(), newForOpOperands,
nullptr,
1872 forOp.getUnsignedCmp());
1879 newForOp.getRegionIterArgs().end());
1881 forOp.getResultTypes().end());
1885 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1886 for (
size_t i = escapingValuesStartIdx;
1887 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
1888 innerWarpInput.push_back(newWarpOp.getResult(i));
1889 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1890 innerWarpInputType.size();
1891 innerWarpInputType.push_back(
1892 escapingValueInputTypes[i - escapingValuesStartIdx]);
1895 auto innerWarp = WarpExecuteOnLane0Op::create(
1896 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
1897 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
1898 innerWarpInputType);
1902 argMapping.push_back(newForOp.getInductionVar());
1903 for (
Value args : innerWarp.getBody()->getArguments())
1904 argMapping.push_back(args);
1906 argMapping.resize(forOp.getBody()->getNumArguments());
1908 for (
Value operand : forOp.getBody()->getTerminator()->getOperands())
1909 yieldOperands.push_back(operand);
1911 rewriter.
eraseOp(forOp.getBody()->getTerminator());
1912 rewriter.
mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1917 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1921 if (!innerWarp.getResults().empty())
1922 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
1926 for (
auto [origIdx, newIdx] : forResultMapping)
1928 newForOp.getResult(newIdx), newForOp);
1931 for (
auto [origIdx, newIdx] : nonForResultMapping)
1933 newWarpOp.getResult(newIdx));
1942 auto it = argIndexMapping.find(operand.get());
1943 if (it == argIndexMapping.end())
1945 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1950 mlir::vector::moveScalarUniformCode(innerWarp);
struct WarpOpReduction : public WarpDistributionPattern {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit),
        distributedReductionFn(std::move(distributedReductionFn)) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp-size-divisible vectors supported.
    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must match warp size.");
    if (!reductionOp.getType().isIntOrFloat())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction distribution currently only supports floats and "
                  "integer types.");

    int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
    // Yield the vector to be reduced (and the accumulator, if any) from the
    // WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {
        VectorType::get({numElements}, reductionOp.getType())};
    if (reductionOp.getAcc()) {
      yieldValues.push_back(reductionOp.getAcc());
      retTypes.push_back(reductionOp.getAcc().getType());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Each lane holds `numElements` elements; reduce them across the warp.
    Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
    Value fullReduce =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    if (reductionOp.getAcc()) {
      fullReduce = vector::makeArithReduction(
          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
          newWarpOp.getResult(newRetIndices[1]));
    }
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};
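
// Example: reducing vector<64xf32> over a 32-lane warp leaves each lane with
// numElements = 64 / 32 = 2 elements; the per-lane partial value is then
// combined across lanes by distributedReductionFn (typically via gpu.shuffle
// exchanges), and any accumulator is folded in afterwards with
// vector::makeArithReduction.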
void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
                                    maxNumElementsToExtract, benefit);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) {
  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
  patterns
      .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
           WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
           WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
           WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
          patterns.getContext(), benefit);
  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
                                    benefit);
  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
                               benefit);
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
                                benefit);
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}
void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return isa<VectorType>(result.getType());
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}