#include <type_traits>

#define GEN_PASS_DEF_CONVERTVECTORTOSCF
#include "mlir/Conversion/Passes.h.inc"

using vector::TransferReadOp;
using vector::TransferWriteOp;
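/// Attribute name used to mark transfer ops that are currently being lowered
/// progressively, so that the patterns below can find the ops they created.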
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Return true if the transfer op operates on a tensor (rather than a memref).
static bool isTensorOp(VectorTransferOpInterface xferOp) {
  if (isa<RankedTensorType>(xferOp.getShapedType())) {
    if (isa<vector::TransferWriteOp>(xferOp))
      assert(xferOp->getNumResults() > 0);
    return true;
  }
  return false;
}

/// Common base class of the patterns in this file; carries the lowering options.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
                                  PatternRewriter &rewriter) const {
    if (isTensorOp(xferOp) && !options.lowerTensors) {
      return rewriter.notifyMatchFailure(
          xferOp, "lowering tensor transfers is disabled");
    }
    return success();
  }

  VectorTransferToSCFOptions options;
};
/// Given a vector transfer op, return the dimension of the source that will be
/// unpacked next (the dim addressed by the first permutation-map result), or
/// std::nullopt if the first vector dimension is a broadcast.
template <typename OpTy>
static std::optional<int64_t> unpackedDim(OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

/// Permutation map of the unpacked (n-1)-D transfer op: the original map with
/// its first result dropped.
template <typename OpTy>
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),

/// Compute the indices of the new (n-1)-D transfer op: the original indices,
/// with the loop induction variable added to the index of the unpacked
/// dimension (unless that dimension is a broadcast).
template <typename OpTy>
  typename OpTy::Adaptor adaptor(xferOp);
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  bool isBroadcast = !dim.has_value();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[*dim];

/// Yield `value` from the current block if `hasRetVal`, otherwise yield
/// nothing.
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Return an i1 that is true iff the iv-th element of a 1-D mask is set.
/// Returns an empty Value if no mask check is needed.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
}
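/// Helper for the transfer op lowerings below: emit an scf.if that guards the
/// access along dimension `dim` (and the 1-D mask bit, if any) and dispatches
/// to the `inBoundsCase` / `outOfBoundsCase` builder callbacks.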
template <typename OpTy>
static Value generateInBoundsCheck(
  bool hasRetVal = !resultTypes.empty();
  bool isBroadcast = !dim;

  // Condition check 1: access in-bounds along the unpacked dim?
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,

  // Condition check 2: mask, if present.
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    cond = lb.create<arith::AndIOp>(cond, maskCond);

  // If a condition was built, wrap the cases in an scf.if.
  auto check = lb.create<scf::IfOp>(
      maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
      if (outOfBoundsCase) {
        maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
      } else {
        b.create<scf::YieldOp>(loc);
      }

  return hasRetVal ? check.getResult(0) : Value();

  // No condition needed: the access is unconditionally in bounds.
  return inBoundsCase(b, loc);
}

/// Variant of generateInBoundsCheck for callers that do not need a result
/// value.
template <typename OpTy>
static void generateInBoundsCheck(
  generateInBoundsCheck(
      inBoundsCase(b, loc);
      outOfBoundsCase(b, loc);

/// Return a copy of `attr` with its first element dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {

/// Add the pass label to a new transfer op if it still has a higher rank than
/// the lowering target rank (i.e. it needs further unpacking).
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}
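/// Temporary buffers for the transferred data and (optionally) the mask,
/// allocated with memref.alloca at the closest enclosing
/// AutomaticAllocationScope.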
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Helper that finds the closest surrounding AutomaticAllocationScope op.
  assert(scope && "Expected op to be inside automatic allocation scope");

template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
  if (xferOp.getMask()) {
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer =
        b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
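/// Given a memref of vectors, move the vector's leading dimension into the
/// memref shape (e.g. memref<...xvector<5x4xf32>> becomes
/// memref<...x5xvector<4xf32>>). Fails if that leading dimension is scalable.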
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with a scalable leading dim are not supported here.
  if (vectorType.getScalableDims().front())
    return failure();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));

/// Given a transfer op whose mask was spilled to a buffer by allocBuffers,
/// return that mask buffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the transfer op kind.
template <typename OpTy>
struct Strategy;
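/// Codegen strategy for TransferReadOp: in each loop iteration the unpacked
/// (n-1)-D slice is read and stored into the temporary data buffer;
/// out-of-bounds iterations store the padding value instead.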
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that stores the result of the TransferReadOp into the
  /// temporary buffer (created by PrepareTransferReadConversion).
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// The temporary buffer is the memref that the read result is stored into.
    return getStoreOp(xferOp).getMemRef();

  /// Indices into the temporary buffer: the indices of the store op.
  static void getBufferIndices(TransferReadOp xferOp,
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());

  /// Rewrite the op: emit the unpacked (n-1)-D transfer_read and store its
  /// result into the buffer at position `iv`.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.getSource(), xferIndices,
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle an out-of-bounds iteration: store the padding value into the
  /// buffer instead of reading.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
    return Value();
  }

  /// Cleanup: erase the temporary store and the original transfer_read.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// A TransferReadOp loop carries no loop state.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
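/// Codegen strategy for TransferWriteOp: in each loop iteration the unpacked
/// (n-1)-D slice is loaded from the temporary buffer and written back to the
/// source; for tensor sources the updated tensor is threaded through the loop
/// as an iter_arg.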
template <>
struct Strategy<TransferWriteOp> {
  /// The temporary buffer is the memref from which the vector operand of the
  /// TransferWriteOp is loaded (see PrepareTransferWriteConversion).
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();

  /// Indices into the temporary buffer: the indices of that load op.
  static void getBufferIndices(TransferWriteOp xferOp,
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());

  /// Rewrite the op: load the (n-1)-D slice from the buffer at position `iv`
  /// and emit an unpacked transfer_write for it.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,

    maybeApplyPassLabel(b, newXferOp, options.targetRank);
    return newXferOp;
  }

  /// Out-of-bounds iterations write nothing; on tensors the current tensor is
  /// passed through unchanged.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup: on tensors, replace the original op with the loop result;
  /// otherwise erase it.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// On tensors, the loop state is initialized with the source tensor.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getSource() : Value();
  }
};

/// Check the preconditions shared by the Prepare* patterns below.
template <typename OpTy>
static LogicalResult checkPrepareXferOp(OpTy xferOp,
                                        PatternRewriter &rewriter,
                                        VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return rewriter.notifyMatchFailure(
        xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return rewriter.notifyMatchFailure(
        xferOp, "xferOp vector rank <= transformation target rank");
  if (xferOp.getVectorType().getScalableDims().front())
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking of the leading dimension into the memref is not yet "
                "supported for scalable dims");
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking for tensors has been disabled.");
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return rewriter.notifyMatchFailure(
        xferOp, "Mismatching source and destination element types.");
  return success();
}
/// Prepare a TransferReadOp for progressive lowering: allocate a temporary
/// buffer, store the read vector into it, and mark the cloned read with
/// kPassLabel so that TransferOpConversion picks it up.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
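/// Prepare a TransferWriteOp for progressive lowering: spill the vector
/// operand to a temporary buffer, rewire the write to read it back from that
/// buffer, and mark the op with kPassLabel.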
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    xferOp.getVectorMutable().assign(loadedVec);
    xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());

    if (xferOp.getMask()) {
      xferOp.getMaskMutable().assign(buffers.maskBuffer);
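/// Decompose printing of an n-D vector into nested scf.for loops that print
/// each element with explicit punctuation. Integer elements narrower than the
/// natively printable widths are first zero- or sign-extended to the next
/// legal power-of-two width.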
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());

    // Scalable vectors of rank > 1 are not supported here.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Widen narrow or non-power-of-two integers to a printable width, going
      // through signless types for the bit casts.
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());

      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
                                                 value);
      if (value.getType() != signlessTargetVectorType) {
        if (width == 1 || intTy.isUnsigned())
          value = rewriter.create<arith::ExtUIOp>(
              loc, signlessTargetVectorType, value);
        else
          value = rewriter.create<arith::ExtSIOp>(
              loc, signlessTargetVectorType, value);
      }
      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1-D for elementwise printing.
      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
    }

    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Create a loop over dimension d, scaling the trip count by vscale for
      // scalable dimensions.
      Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        auto vscale = rewriter.create<vector::VectorScaleOp>(
            loc, rewriter.getIndexType());
        upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
      }
      auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);

      // Print the opening bracket, the loop, and the closing bracket.
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      auto loop =
          rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
      auto printClose = rewriter.create<vector::PrintOp>(
          loc, vector::PrintPunctuation::Close);

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after every element except the last one.
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });

    // Compute the flattened element index from the loop indices.
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      currentStride *= shape[d];
    }

    // Print the element followed by the requested punctuation.
    auto element =
        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
    rewriter.create<vector::PrintOp>(loc, element,
                                     vector::PrintPunctuation::NoPunctuation);

    rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }

  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  }
};
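/// Progressive lowering of the marked vector transfer ops: unpack one vector
/// dimension per application of this pattern. The temporary buffer is
/// reinterpreted via vector.type_cast so that an scf.for loop can store/load
/// one (n-1)-D slice per iteration. Roughly, for a transfer_read (sketch only,
/// with indices, maps and attributes elided):
///
///   %buf = memref.alloca() : memref<vector<5x4xf32>>
///   %cast = vector.type_cast %buf
///       : memref<vector<5x4xf32>> to memref<5xvector<4xf32>>
///   scf.for %i = %c0 to %c5 step %c1 {
///     %slice = vector.transfer_read %A[...], %pad : ..., vector<4xf32>
///     memref.store %slice, %cast[%i] : memref<5xvector<4xf32>>
///   }
///   %vec = memref.load %buf[] : memref<vector<5x4xf32>>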
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

    this->setHasBoundedRewriteRecursion();

  /// Compute the indices into the (possibly casted) mask buffer from which the
  /// (n-1)-D mask slice for iteration `iv` is loaded.
  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());

    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return rewriter.notifyMatchFailure(
          xferOp, "kPassLabel is not present (progressive lowering not yet "
                  "started for this op)");

    // Unpack one dimension of the temporary data buffer.
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return rewriter.notifyMatchFailure(xferOp,
                                         "Failed to unpack one vector dim.");
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);

    // If the op has a mask, unpack the mask buffer analogously (unless the
    // mask is 1-D or the unpacked dim is a broadcast).
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        castedMaskBuffer = maskBuffer;
      } else {
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step for the unpacked dimension.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate the loop; each iteration handles one (n-1)-D slice.
    auto result = locB.create<scf::ForOp>(
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If the mask could not be carried over directly, load the
                // relevant slice from the mask buffer and attach it.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  newXfer.getMaskMutable().assign(mask);
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);

                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);

          maybeYieldValue(b, loc, !loopState.empty(), result);

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

/// Retrieve the mask's per-dimension sizes as OpFoldResults: operands for
/// vector.create_mask, constants for vector.constant_mask (scaled by vscale
/// for scalable dims). Fails for other mask producers.
template <typename VscaleConstantBuilder>
static FailureOr<SmallVector<OpFoldResult>>
getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
      return OpFoldResult(dimSize);
    });
  }
  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
    VectorType maskType = constantMask.getVectorType();
    return llvm::map_to_vector(
        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
          if (maskType.getScalableDims()[dimIdx++])
            return OpFoldResult(createVscaleMultiple(dimSize));
          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
        });
  }
  return failure();
}
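/// Lower a transfer_write of a transposed scalable vector (vector<[N]xMxty>)
/// into an scf.for loop over the scalable dimension: each iteration assembles
/// one vector<Mxty> row of the transposed value with vector.from_elements
/// (reading directly from the un-transposed source) and writes it out,
/// threading the tensor through iter_args when writing to tensors.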
struct ScalableTransposeTransferWriteConversion
    : VectorToSCFPattern<vector::TransferWriteOp> {
  using VectorToSCFPattern::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkLowerTensors(writeOp, rewriter)))
      return failure();

    VectorType vectorType = writeOp.getVectorType();
      return rewriter.notifyMatchFailure(
          writeOp, "expected vector of the form vector<[N]xMxty>");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          writeOp, "non-identity permutations are unsupported (lower first)");
    }

    if (!writeOp.isDimInBounds(0)) {
      return rewriter.notifyMatchFailure(
          writeOp, "out-of-bounds dims are unsupported (use masking)");
    }

    Value vector = writeOp.getVector();
    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
    if (failed(maskDims)) {
      return rewriter.notifyMatchFailure(writeOp,
                                         "failed to resolve mask dims");
    }

    int64_t fixedDimSize = vectorType.getDimSize(1);
    auto fixedDimOffsets = llvm::seq(fixedDimSize);

    // Extract all fixed-size rows of the (un-transposed) source up front.
    auto transposeSource = transposeOp.getVector();
    SmallVector<Value> transposeSourceSlices =
        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
          return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
        });

    // Loop bounds: the scalable dimension, taken from the mask if present.
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    Value sliceMask = nullptr;
    if (!maskDims->empty()) {
      sliceMask = rewriter.create<vector::CreateMaskOp>(
          loc, sliceType.clone(rewriter.getI1Type()),

    Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
    auto result = rewriter.create<scf::ForOp>(
        loc, lb, ub, step, initLoopArgs,
          // Assemble one row of the transposed vector...
          llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
            return b.create<vector::ExtractOp>(
                loc, transposeSourceSlices[idx], iv);
          });
          auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,

          // ...and write it to the destination.
          auto dest =
              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
          auto newWriteOp = b.create<vector::TransferWriteOp>(
              loc, sliceVec, dest, xferIndices,
          if (sliceMask)
            newWriteOp.getMaskMutable().assign(sliceMask);

          b.create<scf::YieldOp>(loc, loopIterArgs.empty()
                                          : newWriteOp.getResult());

    if (isTensorOp(writeOp))

/// Assign the original mask (or the relevant rank-reduced slice of it) to a
/// newly created unrolled transfer op.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }
}
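/// Alternative lowering used when full unrolling is requested: unroll the
/// leading vector dimension of a TransferReadOp into individual (n-1)-D
/// transfer_reads whose results are vector.insert-ed into the result vector.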
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

    setHasBoundedRewriteRecursion();

  /// Result vector that the unrolled reads are inserted into: either the
  /// destination of a single vector.insert user, or a splat of the padding
  /// value.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user and that user is
  /// a vector.insert, return it; otherwise return a null op.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }
    return vector::InsertOp();
  }

  /// Positions at which the unrolled reads are inserted (the positions of the
  /// vector.insert user, if any).
  void getInsertionIndices(TransferReadOp xferOp,
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less than or equal to target rank");
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled at compile time");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    auto vecType = dyn_cast<VectorType>(vec.getType());

    int64_t dimSize = xferVecType.getShape()[0];

    // Generate one (n-1)-D read per unrolled index and insert it into `vec`.
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
            getInsertionIndices(xferOp, insertionIndices);

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = b.create<vector::TransferReadOp>(
                loc, newXferVecType, xferOp.getSource(), xferIndices,
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return b.create<vector::InsertOp>(loc, newXferOp, vec,
                                              insertionIndices);
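/// Unrolled lowering of TransferWriteOp: extract one (n-1)-D slice of the
/// vector per unrolled index and emit an individual transfer_write for it,
/// threading the updated tensor through when writing to tensors.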
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

    setHasBoundedRewriteRecursion();

  /// The vector that is written: if the written value is produced by a
  /// vector.extract, look through it.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  /// If the written vector is produced by a vector.extract, return that op.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// Positions of the vector.extract feeding the write, if any.
  void getExtractionIndices(TransferWriteOp xferOp,
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();

    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();

    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    if (inputVectorTy.getScalableDims()[0]) {

    int64_t dimSize = inputVectorTy.getShape()[0];
    Value source = xferOp.getSource();
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate one (n-1)-D write per unrolled index.
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
            getExtractionIndices(xferOp, extractionIndices);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              xferVec = b.create<vector::BroadcastOp>(
            } else {
              xferVec = extracted;
            }
            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, sourceType, xferVec, source, xferIndices,

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();

            return isTensorOp(xferOp) ? source : Value();

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);
    return success();
  }
};

/// Compute the memref indices for the scalar load/store generated as part of
/// TransferOp1dConversion and return the unpacked dim (or std::nullopt for a
/// broadcast).
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVectorImpl<Value> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =

    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

/// Codegen strategy for the 1-D lowering, depending on the op kind.
template <typename OpTy>
struct Strategy1d;

/// 1-D TransferReadOp lowering: load one scalar per iteration and insert it
/// into the loop-carried result vector (out-of-bounds elements keep the
/// padding from the initial splat).
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
          Value val =
              b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);

    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};

/// 1-D TransferWriteOp lowering: store one scalar per iteration.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    generateInBoundsCheck(
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);

    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};
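/// Lower a 1-D transfer op that cannot be handled by the vector-to-LLVM
/// lowering (e.g. a non-minor-identity permutation map or a memref without a
/// unit-stride innermost dimension) to an scf.for loop of scalar loads/stores
/// combined with vector.insertelement / vector.extractelement.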
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // Reject 0-d transfers and cases already handled elsewhere.
    if (xferOp.getTransferRank() == 0)
      return failure();

    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    // Handled by the vector-to-llvm lowering.
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure();

    // Loop bounds, step and initial loop state.
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate the loop; the strategy emits the per-element load or store.
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
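/// Pattern registration: with full unrolling the Unroll* patterns are used;
/// otherwise the Prepare* + TransferOpConversion pipeline is used, plus the
/// scalable-transpose pattern, the 1-D special case (when targeting rank 1),
/// and the vector.print decomposition.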
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}

struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first so that the transfer patterns only see
    // minor-identity-like maps.
    RewritePatternSet lowerTransferPatterns(&getContext());
    vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsGreedily(getOperation(),
                                std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
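/// Minimal usage sketch (not part of this pass): create the pass with custom
/// options and add it to a pass pipeline. The option field names are the ones
/// used above; the pass-manager calls are standard MLIR APIs, and the function
/// name below is purely illustrative.
static void buildVectorToSCFPipeline(OpPassManager &pm) {
  VectorTransferToSCFOptions options;
  options.targetRank = 1;       // keep unpacking until transfers are 1-D
  options.lowerTensors = true;  // also lower transfers on tensors
  options.unroll = false;       // use scf.for loops instead of full unrolling
  pm.addPass(createConvertVectorToSCFPass(options));
}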