#include "llvm/ADT/STLExtras.h"

#define GEN_PASS_DEF_CONVERTVECTORTOSCF
#include "mlir/Conversion/Passes.h.inc"

using vector::TransferReadOp;
using vector::TransferWriteOp;

static const char kPassLabel[] = "__vector_to_scf_lowering__";
static bool isTensorOp(VectorTransferOpInterface xferOp) {
  if (isa<RankedTensorType>(xferOp.getShapedType())) {
    if (isa<vector::TransferWriteOp>(xferOp)) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
                                  PatternRewriter &rewriter) const {
    if (isTensorOp(xferOp) && !options.lowerTensors) {
      return rewriter.notifyMatchFailure(
          xferOp, "lowering tensor transfers is disabled");
    }
    return success();
  }

  VectorTransferToSCFOptions options;
};
template <typename OpTy>
static std::optional<int64_t> unpackedDim(OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.has_value();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[*dim];
    indices[*dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    scf::YieldOp::create(b, loc, value);
  } else {
    scf::YieldOp::create(b, loc);
  }
}
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv);
}
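// For a 1-D mask this emits, roughly (added for exposition):
//   %mask_bit = vector.extract %mask[%iv] : i1 from vector<5xi1>
// which generateInBoundsCheck below ANDs into the in-bounds condition.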
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // null value means: no condition needed
  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(loc, b);

  // Condition check 1: access in-bounds?
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    Value memrefIdx =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim,
                                 memrefIdx);
  }

  // Condition check 2: mask value?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = arith::AndIOp::create(lb, cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an scf.if.
  if (cond) {
    auto check = scf::IfOp::create(
        b, loc, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            scf::YieldOp::create(b, loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an scf.if.
  return inBoundsCase(b, loc);
}
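// The generated guard looks roughly like (added for exposition):
//   %d = memref.dim %A, %c0 : memref<?x?xf32>
//   %inbounds = arith.cmpi sgt, %d, %idx : index
//   %r = scf.if %inbounds -> (vector<4xf32>) {
//     ... // in-bounds case
//   } else {
//     ... // out-of-bounds case, e.g. yield the padding splat
//   }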
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = memref::AllocaOp::create(b, loc, maskType);
    b.setInsertionPoint(xferOp);
    memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer =
        memref::LoadOp::create(b, loc, maskBuffer, ValueRange());
  }

  return result;
}
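// The resulting IR is roughly (added for exposition):
//   %databuf = memref.alloca() : memref<vector<5x4xf32>>
//   %maskbuf = memref.alloca() : memref<vector<5xi1>>
//   memref.store %mask, %maskbuf[] : memref<vector<5xi1>>
//   %m = memref.load %maskbuf[] : memref<vector<5xi1>>
// Storing and reloading the mask makes the buffer reachable from the mask's
// defining op, which getMaskBuffer below relies on.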
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with leading scalable dims are not supported.
  if (vectorType.getScalableDims().front())
    return failure();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::Builder(vectorType).dropDim(0));
}
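// Example (added for exposition):
//   memref<5xvector<4x3xf32>>  -->  memref<5x4xvector<3xf32>>
// The unpacked dim is materialized in IR with a vector.type_cast on the
// staging buffer.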
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}
template <typename OpTy>
struct Strategy;

/// Code strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that writes the TransferReadOp's result to the
  /// temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation.
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = vector::TransferReadOp::create(
        b, loc, vecType, xferOp.getBase(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer,
                            storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// the padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto vec =
        vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding());
    memref::StoreOp::create(b, loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
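// Per iteration, the loop body built from this strategy looks roughly like
// (added for exposition, 2-D -> 1-D case):
//   %s = vector.transfer_read %A[%idx, %j], %pad
//       : memref<?x?xf32>, vector<4xf32>
//   memref.store %s, %casted_buffer[%iv] : memref<5xvector<4xf32>>
// and out-of-bounds iterations store a splat of the padding value instead.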
/// Code strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation.
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = vector::TransferWriteOp::create(
        b, loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getBase() : Value();
  }
};
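// Per iteration, the write strategy emits roughly (added for exposition):
//   %s = memref.load %casted_buffer[%iv] : memref<5xvector<4xf32>>
//   vector.transfer_write %s, %A[%idx, %j] : vector<4xf32>, memref<?x?xf32>
// For tensors, each iteration instead threads the updated tensor value
// through the scf.for iter_args.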
template <typename OpTy>
static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
                                        VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return rewriter.notifyMatchFailure(
        xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return rewriter.notifyMatchFailure(
        xferOp, "xferOp vector rank <= transformation target rank");
  if (xferOp.getVectorType().getScalableDims().front())
    return rewriter.notifyMatchFailure(
        xferOp, "unpacking of the leading dimension into the memref is not "
                "yet supported for scalable dims");
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return rewriter.notifyMatchFailure(
        xferOp, "unpacking for tensors has been disabled");
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return rewriter.notifyMatchFailure(
        xferOp, "mismatching source and destination element types");
  return success();
}
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkPrepareXferOp(xferOp, rewriter, options)))
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    memref::StoreOp::create(rewriter, loc, newXfer->getResult(0),
                            buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};
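// Added for exposition; the pattern rewrites, roughly:
//   %vec = vector.transfer_read %A[%a, %b] : memref<?x?xf32>, vector<5x4xf32>
// into:
//   %0 = memref.alloca() : memref<vector<5x4xf32>>
//   %1 = vector.transfer_read %A[%a, %b] {__vector_to_scf_lowering__}
//       : memref<?x?xf32>, vector<5x4xf32>
//   memref.store %1, %0[] : memref<vector<5x4xf32>>
//   %vec = memref.load %0[] : memref<vector<5x4xf32>>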
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkPrepareXferOp(xferOp, rewriter, options)))
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    memref::StoreOp::create(rewriter, loc, xferOp.getVector(),
                            buffers.dataBuffer);
    auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer);
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getValueToStoreMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }

    return success();
  }
};
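// Added for exposition; the pattern rewrites, roughly:
//   vector.transfer_write %vec, %A[%a, %b] : vector<5x4xf32>, memref<?x?xf32>
// into:
//   %0 = memref.alloca() : memref<vector<5x4xf32>>
//   memref.store %vec, %0[] : memref<vector<5x4xf32>>
//   %1 = memref.load %0[] : memref<vector<5x4xf32>>
//   vector.transfer_write %1, %A[%a, %b] {__vector_to_scf_lowering__}
//       : vector<5x4xf32>, memref<?x?xf32>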
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    // Currently >= 2-D scalable vectors are not supported: they cannot be
    // lowered to LLVM and cannot be indexed dynamically within a loop nest.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Oddly sized integers are buggy on a lot of backends, so extend them
      // to a more standard size. arith only takes signless integers, so cast
      // back and forth around the extension.
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = vector::BitCastOp::create(rewriter, loc,
                                        signlessSourceVectorType, value);
      if (value.getType() != signlessTargetVectorType) {
        if (width == 1 || intTy.isUnsigned())
          value = arith::ExtUIOp::create(rewriter, loc,
                                         signlessTargetVectorType, value);
        else
          value = arith::ExtSIOp::create(rewriter, loc,
                                         signlessTargetVectorType, value);
      }
      value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1-D to allow indexing with a non-constant
      // value.
      int64_t flatLength = llvm::product_of(shape);
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
    }

    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Setup loop bounds and step.
      Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
      Value upperBound =
          arith::ConstantIndexOp::create(rewriter, loc, shape[d]);
      Value step = arith::ConstantIndexOp::create(rewriter, loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        auto vscale = vector::VectorScaleOp::create(rewriter, loc,
                                                    rewriter.getIndexType());
        upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale);
      }
      auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step);

      // Create a loop to print the elements surrounded by parentheses.
      vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
      auto loop =
          scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
      auto printClose = vector::PrintOp::create(
          rewriter, loc, vector::PrintPunctuation::Close);
      if (!firstClose)
        firstClose = printClose;

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after all but the last element.
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = arith::CmpIOp::create(
          rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      scf::IfOp::create(rewriter, loc, notLastIndex,
                        [&](OpBuilder &builder, Location loc) {
                          vector::PrintOp::create(
                              builder, loc, vector::PrintPunctuation::Comma);
                          scf::YieldOp::create(builder, loc);
                        });

      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Compute the flattened index.
    Value flatIndex;
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride =
          arith::ConstantIndexOp::create(rewriter, loc, currentStride);
      auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }

    // Print the scalar elements in the innermost loop.
    auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex);
    vector::PrintOp::create(rewriter, loc, element,
                            vector::PrintPunctuation::NoPunctuation);

    rewriter.setInsertionPointAfter(firstClose);
    vector::PrintOp::create(rewriter, loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }

  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  }
};
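// Added for exposition; printing %v : vector<2x2xi32> decomposes into nested
// scf.for loops of scalar vector.print ops with Open/Close/Comma
// punctuation, producing output of the form:
//   ( ( 1, 2 ), ( 3, 4 ) )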
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    // Add load indices from the previous iteration, found on the load op that
    // consumes the mask buffer.
    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      // If there is no previous load op, the indices are empty.
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());
        break;
      }
    }

    // In case of broadcast: use the same indices to load from the memref as
    // before.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return rewriter.notifyMatchFailure(
          xferOp, "kPassLabel is not present (op not prepared for lowering)");

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return rewriter.notifyMatchFailure(xferOp,
                                         "failed to unpack one vector dim");

    auto castedDataBuffer =
        vector::TypeCastOp::create(locB, *castedDataType, dataBuffer);

    // If the xferOp has a mask: find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask if:
        // * the to-be-unpacked transfer op dimension is a broadcast, or
        // * the mask is 1-D, i.e., it cannot be further unpacked.
        castedMaskBuffer = maskBuffer;
      } else {
        // It's safe to assume the mask buffer can be unpacked if the data
        // buffer was unpacked.
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            vector::TypeCastOp::create(locB, castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = arith::ConstantIndexOp::create(locB, 0);
    auto ub = arith::ConstantIndexOp::create(
        locB, castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = arith::ConstantIndexOp::create(locB, 1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = scf::ForOp::create(
        locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create a new transfer op of lower rank.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If the old transfer op has a mask: set the mask on the new
                // transfer op. If the mask is 1-D and the unpacked dim is not
                // a broadcast, no mask is needed on the new transfer op.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer,
                                                     loadIndices);
                  rewriter.modifyOpInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};
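// The overall structure produced for a 2-D read is roughly (added for
// exposition):
//   %cast = vector.type_cast %buffer
//       : memref<vector<5x4xf32>> to memref<5xvector<4xf32>>
//   scf.for %iv = %c0 to %c5 step %c1 {
//     scf.if <in-bounds for dim 0> {
//       ... // 1-D transfer_read, stored into %cast[%iv]
//     } else {
//       ... // store a splat of the padding value
//     }
//   }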
template <typename VscaleConstantBuilder>
static FailureOr<SmallVector<OpFoldResult>>
getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
  if (!mask)
    return SmallVector<OpFoldResult>{};
  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
      return OpFoldResult(dimSize);
    });
  }
  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
    int64_t dimIdx = 0;
    VectorType maskType = constantMask.getVectorType();
    auto indexType = IndexType::get(mask.getContext());
    return llvm::map_to_vector(
        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
          // A scalable dim in a constant_mask means vscale x dimSize.
          if (maskType.getScalableDims()[dimIdx++])
            return OpFoldResult(createVscaleMultiple(dimSize));
          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
        });
  }
  return failure();
}
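// Example (added for exposition): for
//   %mask = vector.create_mask %a, %b : vector<[4]x4xi1>
// this returns {%a, %b}; for
//   %mask = vector.constant_mask [4, 2] : vector<[4]x4xi1>
// it returns {vscale * 4, 2}, with the scalable dim scaled by vscale.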
struct ScalableTransposeTransferWriteConversion
    : VectorToSCFPattern<vector::TransferWriteOp> {
  using VectorToSCFPattern::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkLowerTensors(writeOp, rewriter)))
      return failure();

    VectorType vectorType = writeOp.getVectorType();

    // Comparing the scalable dims to an ArrayRef of length two implicitly
    // checks that the rank is also two.
    if (vectorType.getScalableDims() != ArrayRef<bool>{true, false}) {
      return rewriter.notifyMatchFailure(
          writeOp, "expected vector of the form vector<[N]xMxty>");
    }

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          writeOp, "non-identity permutations are unsupported (lower first)");
    }

    // This pattern only lowers the leading dimension (to a loop), so only the
    // leading dimension needs to be in bounds; the in-bounds attribute of the
    // trailing dimension is propagated.
    if (!writeOp.isDimInBounds(0)) {
      return rewriter.notifyMatchFailure(
          writeOp, "out-of-bounds dims are unsupported (use masking)");
    }

    Value vector = writeOp.getVector();
    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
    if (!transposeOp ||
        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
    }

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
    if (failed(maskDims)) {
      return rewriter.notifyMatchFailure(writeOp,
                                         "failed to resolve mask dims");
    }

    int64_t fixedDimSize = vectorType.getDimSize(1);
    auto fixedDimOffsets = llvm::seq(fixedDimSize);

    // Extract all slices from the source of the transpose.
    auto transposeSource = transposeOp.getVector();
    SmallVector<Value> transposeSourceSlices =
        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
          return vector::ExtractOp::create(rewriter, loc, transposeSource, idx);
        });

    // Loop bounds and step.
    auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto ub =
        maskDims->empty()
            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);

    // Generate a new mask for the slice.
    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
    Value sliceMask = nullptr;
    if (!maskDims->empty()) {
      sliceMask = vector::CreateMaskOp::create(
          rewriter, loc, sliceType.clone(rewriter.getI1Type()),
          ArrayRef<OpFoldResult>(*maskDims).drop_front());
    }

    Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
    auto result = scf::ForOp::create(
        rewriter, loc, lb, ub, step, initLoopArgs,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
          // Indices for the new transfer op.
          SmallVector<Value, 8> xferIndices;
          getXferIndices(b, writeOp, iv, xferIndices);

          // Extract a transposed slice from the source vector.
          SmallVector<Value> transposeElements =
              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
                return vector::ExtractOp::create(
                    b, loc, transposeSourceSlices[idx], iv);
              });
          auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType,
                                                         transposeElements);

          // Create the transfer_write for the slice.
          Value dest =
              loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
          auto newWriteOp = vector::TransferWriteOp::create(
              b, loc, sliceVec, dest, xferIndices,
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
          if (sliceMask)
            newWriteOp.getMaskMutable().assign(sliceMask);

          // Yield from the loop.
          scf::YieldOp::create(b, loc,
                               loopIterArgs.empty() ? ValueRange{}
                                                    : newWriteOp.getResult());
        });

    if (isTensorOp(writeOp))
      rewriter.replaceOp(writeOp, result);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};
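// Added for exposition; the pattern turns
//   %t = vector.transpose %v, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
//   vector.transfer_write %t, %A[%i, %j] : vector<[4]x4xf32>, memref<?x?xf32>
// into an scf.for over the scalable dim that assembles each 4-element row
// with vector.from_elements and writes it with a 1-D transfer_write.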
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // The to-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. The mask remains unchanged.
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }

  // If we end up here: the mask of the old transfer op is 1-D and the
  // unpacked dim is not a broadcast, so no mask is needed on the new op.
}
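// E.g. (added for exposition): unrolling iteration i of a 2-D masked read
// extracts the i-th mask row:
//   %row_mask = vector.extract %mask[i] : vector<4xi1> from vector<5x4xi1>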
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Get or build the vector into which the newly created TransferReadOp
  /// results are inserted.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(),
                                       xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }
    return vector::InsertOp();
  }

  /// If the result feeds a vector::InsertOp, return that op's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  /// Rewrite the op: fully unroll one dimension. Can handle masks,
  /// out-of-bounds accesses, and broadcasts and transposes in permutation
  /// maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less than or equal to target rank");
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled at compile time");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    auto vecType = dyn_cast<VectorType>(vec.getType());

    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);

    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled sequence of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(rewriter.getIndexAttr(i));

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = vector::TransferReadOp::create(
                b, loc, newXferVecType, xferOp.getBase(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);

            Value valToInsert = newXferOp.getResult();
            if (newXferVecType.getRank() == 0) {
              // vector.insert does not accept rank-0 vectors. Extract the
              // scalar before inserting.
              valToInsert = vector::ExtractOp::create(b, loc, valToInsert,
                                                      SmallVector<int64_t>());
            }
            return vector::InsertOp::create(b, loc, valToInsert, vec,
                                            insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through the original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite the single user of the old TransferReadOp, which was a
      // vector.insert op.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};
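// Added for exposition; unrolling a vector<3x4xf32> read with targetRank = 1
// yields roughly:
//   %v0 = vector.transfer_read %A[%i, %j], %pad ... : vector<4xf32>
//   %r0 = vector.insert %v0, %init [0] : vector<4xf32> into vector<3x4xf32>
//   %v1 = vector.transfer_read %A[%i1, %j], %pad ... : vector<4xf32>
//   %r1 = vector.insert %v1, %r0 [1] : vector<4xf32> into vector<3x4xf32>
//   ... // one read + insert per outer index, guarded by scf.if when needed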
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getSource();
    return xferOp.getVector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  /// Rewrite the op: fully unroll one dimension. Can handle masks,
  /// out-of-bounds accesses, and broadcasts and transposes in permutation
  /// maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    if (inputVectorTy.getScalableDims()[0]) {
      // Cannot unroll a scalable dimension at compile time.
      return failure();
    }

    int64_t dimSize = inputVectorTy.getShape()[0];
    Value source = xferOp.getBase(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate fully unrolled sequence of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          sourceType ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(b.getI64IntegerAttr(i));

            auto extracted =
                vector::ExtractOp::create(b, loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              // When targetRank == 0, unrolling would cause the
              // vector.extract to produce a scalar. vector.transfer_write
              // cannot write a scalar, so re-broadcast it to a 0-d vector.
              xferVec = vector::BroadcastOp::create(
                  b, loc, VectorType::get({}, extracted.getType()), extracted);
            } else {
              xferVec = extracted;
            }
            auto newXferOp = vector::TransferWriteOp::create(
                b, loc, sourceType, xferVec, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}
/// Codegen strategy for TransferOp1dConversion, depending on the operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (it was initialized
    // with the padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices);
          return vector::InsertOp::create(b, loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    scf::YieldOp::create(b, loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
    Location loc = xferOp.getLoc();
    return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(),
                                       xferOp.getPadding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv);
          memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices);
        });
    scf::YieldOp::create(b, loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure(); // Handled by ConvertVectorToLLVM pass.

    // Loop bounds, step, initial state.
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value ub =
        arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale = vector::VectorScaleOp::create(rewriter, loc,
                                                   rewriter.getIndexType());
      ub = arith::MulIOp::create(rewriter, loc, ub, vscale);
    }
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};
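// Added for exposition; a 1-D strided read such as
//   %v = vector.transfer_read %A[%a, %b]
//       {permutation_map = affine_map<(d0, d1) -> (d0)>}
//       : memref<?x?xf32>, vector<9xf32>
// becomes, roughly:
//   %r = scf.for %i = %c0 to %c9 step %c1 iter_args(%acc = %splat)
//       -> (vector<9xf32>) {
//     %idx = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %i)
//     %e = memref.load %A[%idx, %b] : memref<?x?xf32>
//     %next = vector.insert %e, %acc [%i] : f32 into vector<9xf32>
//     scf.yield %next : vector<9xf32>
//   }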
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}
struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  explicit ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
    vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsGreedily(getOperation(),
                                std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
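// Usage sketch (added for exposition, assuming the standard MLIR pass
// infrastructure):
//
//   VectorTransferToSCFOptions options;
//   options.targetRank = 1;      // lower until transfers are 1-D
//   options.lowerTensors = true; // also lower transfers on tensors
//   pm.addPass(createConvertVectorToSCFPass(options));
//
// or from the command line: mlir-opt -convert-vector-to-scf ...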