#include <type_traits>

#define GEN_PASS_DEF_CONVERTVECTORTOSCF
#include "mlir/Conversion/Passes.h.inc"

using vector::TransferReadOp;
using vector::TransferWriteOp;

// Attribute used to mark transfer ops whose progressive lowering is still in
// flight; the patterns below only fire on ops that carry (or lack) it.
static const char kPassLabel[] = "__vector_to_scf_lowering__";
// Return true if the transfer op operates on tensors rather than memrefs.
static bool isTensorOp(VectorTransferOpInterface xferOp) {
  if (isa<RankedTensorType>(xferOp.getShapedType())) {
    if (isa<vector::TransferWriteOp>(xferOp)) {
      // Tensor transfer writes return the updated tensor as a result.
      assert(xferOp->getNumResults() > 0);
template <typename OpTy>

// Reject transfers on tensors unless tensor lowering was enabled in the
// options.
LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
                                PatternRewriter &rewriter) const {
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return rewriter.notifyMatchFailure(
        xferOp, "lowering tensor transfers is disabled");
// Return the source dimension that is unpacked by this transfer (the dim of
// the first permutation-map result), or std::nullopt if the leading vector
// dimension is a broadcast.
template <typename OpTy>
static std::optional<int64_t> unpackedDim(OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
// Compute the permutation map of the (N-1)-D transfer op that results from
// unpacking the leading dimension: drop the first result of the current map.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}
// Compute the indices of the new (unpacked) transfer op: copy the original
// indices and, unless the unpacked dim is a broadcast, add the loop induction
// variable to the index of that dimension.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  bool isBroadcast = !dim.has_value();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[*dim];
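The bound dimensions feed an affine apply that adds the loop induction variable to the original offset. A minimal sketch of that idiom, assuming the MLIR affine helpers used elsewhere in this file (the helper name addIndexValues is illustrative only, not part of the original):

// Illustrative sketch; not part of the original file.
static Value addIndexValues(OpBuilder &b, Location loc, Value offset, Value iv) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  // Build a (folded) affine.apply computing `offset + iv`.
  return affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
}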
// Yield `value` from the current block if `hasRetVal`, otherwise emit an
// empty scf.yield.
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}
// Generate a check of the iv-th mask bit, but only if the transfer has a 1-D
// mask and the unpacked dim is not a broadcast; otherwise no check is emitted.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
}
// Emit a runtime in-bounds check for dimension `dim` at loop iteration `iv`
// and dispatch to `inBoundsCase` or `outOfBoundsCase`; the bounds check is
// combined with the 1-D mask check (if any) via arith.andi.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();

  // Condition check 1: access in bounds? (No check needed for broadcasts.)
  bool isBroadcast = !dim;
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    cond = lb.create<arith::AndIOp>(cond, maskCond);
  }

  // If a condition was generated, branch with an scf.if and yield the result
  // of the taken branch.
  auto check = lb.create<scf::IfOp>(
      cond,
      /*thenBuilder=*/[&](OpBuilder &b, Location loc) {
        maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
      },
      /*elseBuilder=*/[&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase) {
          maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
        } else {
          b.create<scf::YieldOp>(loc);
        }
      });
  return hasRetVal ? check.getResult(0) : Value();

  // If no check was needed, emit the in-bounds case unconditionally.
  return inBoundsCase(b, loc);
// Overload of generateInBoundsCheck for callbacks that do not return a value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
// Mark the newly created transfer op with kPassLabel if it still needs further
// unpacking, i.e. its vector rank is still above the target rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
// Temporary buffers for the data vector and (optionally) the mask, allocated
// in the closest enclosing automatic allocation scope.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

// getAutomaticAllocationScope: find the closest surrounding op that has the
// AutomaticAllocationScope trait (e.g. a function).
  assert(scope && "Expected op to be inside automatic allocation scope");

template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");

  BufferAllocs result;
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.getMask()) {
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    // Store the mask and load it back so that it can later be addressed
    // through the buffer.
    b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer =
        b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
// Unpack one leading dimension of the vector element type into the memref
// shape. Fails if the leading vector dimension is scalable.
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with a leading scalable dim are not supported.
  if (vectorType.getScalableDims().front())
    return failure();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
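As an illustration (not part of the original file), unpacking one dimension moves the leading vector dimension into the memref shape:

//   memref<7xvector<4x8xf32>>   --unpackOneDim-->   memref<7x4xvector<8xf32>>
// A scalable leading vector dimension (e.g. vector<[4]x8xf32>) is rejected and
// the function returns failure().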
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}
template <typename OpTy>
struct Strategy;

// Lowering strategy for TransferReadOp: each unpacked (N-1)-d read is stored
// into the temporary buffer created by the prepare pattern.
template <>
struct Strategy<TransferReadOp> {
  // Find the single memref.store that consumes the TransferReadOp result.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  // The temporary buffer is the memref written to by that store.
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  // Rewrite the read as an (N-1)-d read of slice `iv` and store the result
  // into the buffer.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.getBase(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
    return newXferOp;
  }

  // Out-of-bounds slices are filled with the padding value.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
    return Value();
  }

  // Cleanup: the store into the temporary buffer is now dead.
    rewriter.eraseOp(getStoreOp(xferOp));

  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
// Lowering strategy for TransferWriteOp: each unpacked (N-1)-d slice is loaded
// back from the temporary buffer and written with a new transfer op.
template <>
struct Strategy<TransferWriteOp> {
  // The buffer is the memref from which the written vector was loaded.
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    // On tensors the destination is threaded through the loop's iter_args.
    auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);
    return newXferOp;
  }

  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  // Tensor writes are replaced by the loop's result.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    }
  }

  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getBase() : Value();
  }
};
// Preconditions shared by the two "prepare" patterns below.
template <typename OpTy>
static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
                                        const VectorTransferToSCFOptions &options) {
  if (xferOp->hasAttr(kPassLabel))
    return rewriter.notifyMatchFailure(
        xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return rewriter.notifyMatchFailure(
        xferOp, "xferOp vector rank <= transformation target rank");
  if (xferOp.getVectorType().getScalableDims().front())
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking of the leading dimension into the memref is not yet "
                "supported for scalable dims");
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking for tensors has been disabled.");
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return rewriter.notifyMatchFailure(
        xferOp, "Mismatching source and destination element types.");
  return success();
}
// Prepare a TransferReadOp for progressive lowering: allocate a temporary
// buffer, clone the read with the kPassLabel attribute, and store its result
// into the buffer.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
// Prepare a TransferWriteOp: store the vector to be written into a temporary
// buffer, reload it, and mark the op with kPassLabel so that the progressive
// lowering pattern picks it up.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    xferOp.getValueToStoreMutable().assign(loadedVec);
    xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());

    if (xferOp.getMask()) {
      xferOp.getMaskMutable().assign(buffers.maskBuffer);
// Decompose a vector.print into a loop nest of scalar element prints with the
// appropriate punctuation.
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    Location loc = printOp.getLoc();

    // Scalable vectors of rank > 1 are not supported here.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Widen sub-byte / non-power-of-two integers to the next legal width
      // (at least 8 bits) and strip signedness before printing.
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());

      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
                                                 value);
      if (value.getType() != signlessTargetVectorType) {
        if (width == 1 || intTy.isUnsigned())
          value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
                                                  value);
        else
          value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
                                                  value);
      }
      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1-D before extracting elements.
      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
    }

    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Loop bounds and step for dimension d (scaled by vscale if scalable).
      Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        auto vscale = rewriter.create<vector::VectorScaleOp>(
            loc, rewriter.getIndexType());
        upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
      }
      auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);

      // Print an opening parenthesis and create the loop for this dimension.
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      auto loop =
          rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
      rewriter.create<vector::PrintOp>(
          loc, vector::PrintPunctuation::Close);

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after every element except the last one.
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });
    }

    // Compute the flattened element index from the loop induction variables.
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      currentStride *= shape[d];
    }

    // Print the scalar element, then the requested trailing punctuation.
    auto element =
        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
    rewriter.create<vector::PrintOp>(loc, element,
                                     vector::PrintPunctuation::NoPunctuation);

    rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());

  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
// Progressive lowering of vector transfer ops that carry kPassLabel: each
// application peels the leading vector dimension off into an scf.for loop.
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

    // The pattern re-creates transfer ops that it matches again; recursion is
    // bounded because the vector rank strictly decreases.
    this->setHasBoundedRewriteRecursion();

  // Compute the indices for loading the current iteration's mask slice from
  // the (casted) mask buffer.
  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    // Reuse the indices of the load that feeds the mask operand.
    Value maskBuffer = getMaskBuffer(xferOp);
    if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
      loadIndices.append(prevIndices.begin(), prevIndices.end());

    // Broadcast dims are not unpacked, so no extra index is needed for them.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return rewriter.notifyMatchFailure(
          xferOp, "op not prepared for progressive lowering "
                  "(kPassLabel attribute is missing)");

    // Find and cast the data buffer: move the leading vector dimension into
    // the memref shape.
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return rewriter.notifyMatchFailure(xferOp,
                                         "Failed to unpack one vector dim.");

    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);

    // If the op is masked, cast the mask buffer the same way. 1-D masks and
    // broadcast dims reuse the original buffer unchanged.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        castedMaskBuffer = maskBuffer;
      } else {
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop over the unpacked (leading) dimension.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps on tensors thread the updated tensor through iter_args.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    auto result = locB.create<scf::ForOp>(
        // Loop body:
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              // In-bounds case: rewrite as an (N-1)-d transfer of slice `iv`.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // For (N>1)-D or broadcast masks, load this iteration's mask
                // slice and attach it to the new transfer op.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  newXfer.getMaskMutable().assign(mask);
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              // Out-of-bounds case:
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);

          maybeYieldValue(b, loc, !loopState.empty(), result);

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
// Retrieve the dimension sizes of a mask produced by vector.create_mask or
// vector.constant_mask; scalable constant dims are multiplied by vscale.
template <typename VscaleConstantBuilder>
static FailureOr<SmallVector<OpFoldResult>>
getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
      return OpFoldResult(dimSize);
    });
  }
  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
    int dimIdx = 0;
    VectorType maskType = constantMask.getVectorType();
    auto indexType = IndexType::get(mask.getContext());
    return llvm::map_to_vector(
        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
          // A scalable constant dim size is `dimSize * vscale`.
          if (maskType.getScalableDims()[dimIdx++])
            return OpFoldResult(createVscaleMultiple(dimSize));
          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
        });
  }
  return failure();
}
// Lower `vector.transfer_write(vector.transpose)` with a leading scalable
// dimension (i.e. a vector<[N]xMxty> operand) to an scf.for loop of 1-D
// writes, one transposed slice per iteration.
struct ScalableTransposeTransferWriteConversion
    : VectorToSCFPattern<vector::TransferWriteOp> {
  using VectorToSCFPattern::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkLowerTensors(writeOp, rewriter)))
      return failure();

    VectorType vectorType = writeOp.getVectorType();
      return rewriter.notifyMatchFailure(
          writeOp, "expected vector of the form vector<[N]xMxty>");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          writeOp, "non-identity permutations are unsupported (lower first)");
    }

    // Note: currently only the leading dimension is handled, and it must be
    // in bounds.
    if (!writeOp.isDimInBounds(0)) {
      return rewriter.notifyMatchFailure(
          writeOp, "out-of-bounds dims are unsupported (use masking)");
    }

    Value vector = writeOp.getVector();
    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
    if (failed(maskDims)) {
      return rewriter.notifyMatchFailure(writeOp,
                                         "failed to resolve mask dims");
    }

    int64_t fixedDimSize = vectorType.getDimSize(1);
    auto fixedDimOffsets = llvm::seq(fixedDimSize);

    // Extract all slices of the transpose source before the loop (they are
    // loop invariant).
    auto transposeSource = transposeOp.getVector();
    SmallVector<Value> transposeSourceSlices =
        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
        });

    // Loop bounds, step, and the (optional) mask of one 1-D slice.
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto ub =
        maskDims->empty()
            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    Value sliceMask = nullptr;
    if (!maskDims->empty()) {
      sliceMask = rewriter.create<vector::CreateMaskOp>(
          loc, sliceType.clone(rewriter.getI1Type()),
          ArrayRef<OpFoldResult>(*maskDims).drop_front());
    }

    Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};

    auto result = rewriter.create<scf::ForOp>(
        loc, lb, ub, step, initLoopArgs,
        // Loop body: build one transposed slice and write it out.
          SmallVector<Value> transposeElements =
              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
                return b.create<vector::ExtractOp>(
                    loc, transposeSourceSlices[idx], iv);
              });
          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
                                                           transposeElements);

          // On tensors the destination flows through the loop's iter_args.
          Value dest =
              loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
          auto newWriteOp = b.create<vector::TransferWriteOp>(
              loc, sliceVec, dest, xferIndices, /*...*/);
          if (sliceMask)
            newWriteOp.getMaskMutable().assign(sliceMask);

          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
                                          ? ValueRange{}
                                          : newWriteOp.getResult());

    if (isTensorOp(writeOp))
      rewriter.replaceOp(writeOp, result);
// If the original transfer op has a mask, propagate it to the new (unrolled)
// transfer op: broadcast dims reuse the full mask, while (N>1)-D masks are
// sliced with vector.extract at position `i`.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // Unpacked dim is a broadcast: the same mask applies to every slice.
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Extract the mask slice for unrolled iteration i.
    llvm::SmallVector<int64_t, 1> indices({i});
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }
}
// Unrolled lowering of a TransferReadOp: instead of generating an scf.for
// loop, emit `dimSize` (N-1)-d transfer reads and insert each result into the
// destination vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

    setHasBoundedRewriteRecursion();

  // The result vector that slices are inserted into: either the destination of
  // an existing vector.insert user, or a fresh splat of the padding value.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.getPadding());
  }

  // If the sole user of this transfer is a vector.insert, return it.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }
    return vector::InsertOp();
  }

  // Indices of the vector.insert user, if any.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less than or equal to the target rank");
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled at compile time");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    auto vecType = dyn_cast<VectorType>(vec.getType());

    int64_t dimSize = xferVecType.getShape()[0];

    // Generate one (N-1)-d transfer read per unrolled index.
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          // In-bounds case: read the slice and insert it into `vec`.
            SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(rewriter.getIndexAttr(i));

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.getBase(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
// Unrolled lowering of a TransferWriteOp: emit `dimSize` (N-1)-d transfer
// writes, one per extracted slice of the source vector, instead of a loop.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

    setHasBoundedRewriteRecursion();

  // The vector that slices are extracted from: either the source of an
  // existing vector.extract, or the written vector itself.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  // If the written vector is produced by a vector.extract, return that op.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();

    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();

    // Transfers that change the element type are not supported here.
    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    if (inputVectorTy.getScalableDims()[0]) {
      // A scalable leading dimension cannot be unrolled at compile time.
      return failure();
    }

    int64_t dimSize = inputVectorTy.getShape()[0];
    Value source = xferOp.getBase();
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate one (N-1)-d transfer write per unrolled index; tensor writes
    // thread the updated tensor from one write to the next.
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          // In-bounds case: extract the slice and write it out.
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              // Unrolling a 1-D vector yields scalars; wrap the scalar back
              // into a 0-d vector for the transfer write.
              xferVec = b.create<vector::BroadcastOp>(
                  loc, VectorType::get({}, extracted.getType()), extracted);
            } else {
              xferVec = extracted;
            }
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, xferVec, source, xferIndices, /*...*/);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          // Out-of-bounds case: keep the current destination.
            return isTensorOp(xferOp) ? source : Value();

      if (isTensorOp(xferOp))
        source = updatedSource;

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
// Compute the memref indices of a 1-D transfer at loop iteration `iv`: copy
// the original indices and add `iv` to the accessed dimension. Returns that
// dimension, or std::nullopt for broadcasts.
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
template <typename OpTy>
struct Strategy1d;

// 1-D read strategy: load one element per iteration and insert it into the
// accumulated result vector (carried through the loop's iter_args).
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location /*loc*/) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};

// 1-D write strategy: extract one element per iteration and store it.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
// Lower a 1-D transfer op with a non-contiguous access pattern to an scf.for
// loop over scalar loads/stores. Contiguous (minor-identity, unit-stride)
// 1-D transfers are left to the vector-to-LLVM lowering.
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure(); // Handled by the vector-to-LLVM lowering.

    // Loop bounds, step, and initial loop state.
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      // Scalable vectors iterate over dimSize * vscale elements.
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate the loop; its body is emitted by the 1-D strategy.
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}
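A hedged usage sketch (not part of the original file): a downstream pass can populate these patterns and apply them greedily. The pass name MyVectorLoweringPass and the surrounding boilerplate are illustrative only.

// Illustrative sketch; assumes the standard MLIR greedy pattern driver.
void MyVectorLoweringPass::runOnOperation() {
  VectorTransferToSCFOptions options;
  options.targetRank = 1;      // lower all the way down to 1-D transfers
  options.lowerTensors = true; // also lower transfers on tensors

  RewritePatternSet patterns(&getContext());
  populateVectorToSCFConversionPatterns(patterns, options);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}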
struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first, so that the transfer patterns only see
    // minor-identity maps.
    RewritePatternSet lowerTransferPatterns(&getContext());
    mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsGreedily(getOperation(),
                                std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
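For completeness, a hedged sketch of scheduling the pass in a pipeline (not part of the original file; the pipeline function and nesting are illustrative):

// Illustrative sketch only.
void buildMyVectorPipeline(OpPassManager &pm) {
  VectorTransferToSCFOptions options;
  options.unroll = true; // unroll instead of emitting scf.for loops
  pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass(options));
}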