#define GEN_PASS_DEF_CONVERTVECTORTOSCF
#include "mlir/Conversion/Passes.h.inc"

using vector::TransferReadOp;
using vector::TransferWriteOp;
static const char kPassLabel[] = "__vector_to_scf_lowering__";
/// Return true if the transfer op operates on a tensor (rather than a memref).
static bool isTensorOp(VectorTransferOpInterface xferOp) {
  if (isa<RankedTensorType>(xferOp.getShapedType())) {
    if (isa<vector::TransferWriteOp>(xferOp)) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}
/// Patterns that inherit from this struct have access to the lowering options.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
                                  PatternRewriter &rewriter) const {
    if (isTensorOp(xferOp) && !options.lowerTensors) {
      return rewriter.notifyMatchFailure(
          xferOp, "lowering tensor transfers is disabled");
    }
    return success();
  }

  VectorTransferToSCFOptions options;
};
/// Return the memref dimension that the unpacked (major) vector dimension
/// maps to, or std::nullopt if that dimension is a broadcast.
template <typename OpTy>
static std::optional<int64_t> unpackedDim(OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}
/// Compute the permutation map of the new (N-1)-D transfer op: the original
/// map without its first result.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}
/// Compute the indices of the new transfer op: same as the original indices,
/// but with the loop induction variable `iv` added to the unpacked dimension.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  bool isBroadcast = !dim.has_value();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[*dim];
    indices[*dim] = affine::makeComposedAffineApply(b, xferOp.getLoc(), d0 + d1,
                                                    {offset, iv});
  }
}
/// Yield `value` from the surrounding scf construct, if one is expected.
static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    scf::YieldOp::create(b, loc, value);
  } else {
    scf::YieldOp::create(b, loc);
  }
}
/// If the transfer op has a 1-D, non-broadcast mask, return the mask bit for
/// position `iv`; otherwise return an empty Value.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv);
}
/// Emit an scf.if that guards the access at `iv` with an in-bounds check and,
/// if a 1-D mask is present, a mask check.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(loc, b);
  bool isBroadcast = !dim; // No in-bounds check for broadcast dimensions.

  // Condition check 1: the access is within the memref dimension.
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    Value memrefIdx =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim,
                                 memrefIdx);
  }

  // Condition check 2: the mask bit for this position is set.
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = arith::AndIOp::create(lb, cond, maskCond);
    else
      cond = maskCond;
  }
  // If a condition was built, wrap both cases in an scf.if.
  if (cond) {
    auto check = scf::IfOp::create(
        b, loc, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            scf::YieldOp::create(b, loc);
          }
        });
    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is statically known to hold: no scf.if needed.
  return inBoundsCase(b, loc);
}
/// Overload of generateInBoundsCheck for callbacks that do not return a value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}
/// Given an ArrayAttr, return a copy with the first element dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}
/// Add the pass label to a vector transfer op if its rank is still above the
/// target rank, so that it is picked up again by the progressive lowering.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}
/// Temporary data structure for holding the results of buffer allocations.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Allocate temporary buffers (memref.alloca) for the data vector and, if the
/// transfer op is masked, for the mask.
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = memref::AllocaOp::create(b, loc, maskType);
    b.setInsertionPoint(xferOp);
    memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer =
        memref::LoadOp::create(b, loc, maskBuffer, ValueRange());
  }
  return result;
}
/// Given a MemRefType with vector element type, unpack one dimension from the
/// vector type into the memref shape.
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with a leading scalable dim are not supported.
  if (vectorType.getScalableDims().front())
    return failure();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::Builder(vectorType).dropDim(0));
}
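// Example (illustrative): unpackOneDim(memref<7xvector<4x5xf32>>) yields
// memref<7x4xvector<5xf32>>; a leading scalable vector dim makes it fail.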
/// Given a transfer op whose mask was materialized into a buffer, return that
/// buffer (i.e. the memref from which the mask is loaded).
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}
/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that writes the TransferReadOp's result into the
  /// temporary buffer.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation.
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the StoreOp that writes into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }
  /// Rewrite the op for the current loop iteration `iv`, assuming the access
  /// is in bounds: emit an (N-1)-D TransferReadOp and store its result into
  /// the buffer.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = vector::TransferReadOp::create(
        b, loc, vecType, xferOp.getBase(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer,
                            storeIndices);
    return newXferOp;
  }
  /// Handle an out-of-bounds access for the current loop iteration `iv`: store
  /// the padding value into the buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto vec =
        vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding());
    memref::StoreOp::create(b, loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// A TransferReadOp carries no loop state (iter_args).
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation: the to-be-written vector is loaded
  /// from the buffer via a memref.load.
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the LoadOp that reads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }
  /// Rewrite the op for the current loop iteration `iv`, assuming the access
  /// is in bounds: load the vector from the buffer and emit an (N-1)-D
  /// TransferWriteOp.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = vector::TransferWriteOp::create(
        b, loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);
    return newXferOp;
  }
  /// Handle an out-of-bounds access: nothing to write; just forward the loop
  /// state (the unmodified tensor), if any.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// TransferWriteOps on tensors carry the tensor as loop state (iter_arg).
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getBase() : Value();
  }
};
/// Check preconditions of the Prepare* patterns below.
template <typename OpTy>
static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
                                        VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return rewriter.notifyMatchFailure(
        xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return rewriter.notifyMatchFailure(
        xferOp, "xferOp vector rank <= transformation target rank");
  if (xferOp.getVectorType().getScalableDims().front())
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking of the leading dimension into the memref is not yet "
                "supported for scalable dims");
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking for tensors has been disabled.");
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return rewriter.notifyMatchFailure(
        xferOp, "Mismatching source and destination element types.");
  return success();
}
/// Prepare a TransferReadOp for progressive lowering: allocate a temporary
/// buffer, store the read vector into it and mark the cloned read with the
/// pass label so that TransferOpConversion picks it up.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    memref::StoreOp::create(rewriter, loc, newXfer->getResult(0),
                            buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
    return success();
  }
};
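// PrepareTransferWriteConversion (below) performs the mirror-image rewrite:
// the to-be-written vector is spilled to a temporary buffer and re-loaded, so
// that the labeled transfer_write consumes a value loaded from memory.
// Illustrative IR only (shapes and value names are made up for the example):
//   %buf = memref.alloca() : memref<vector<5x4xf32>>
//   memref.store %vec, %buf[] : memref<vector<5x4xf32>>
//   %0 = memref.load %buf[] : memref<vector<5x4xf32>>
//   vector.transfer_write %0, %A[%a, %b] {__vector_to_scf_lowering__} ...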
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    memref::StoreOp::create(rewriter, loc, xferOp.getVector(),
                            buffers.dataBuffer);
    auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer);
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getValueToStoreMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }
    return success();
  }
};
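// DecomposePrintOpConversion (below) decomposes a vector.print of an n-D
// vector into a loop nest of scf.for ops that extract and print one element at
// a time, emitting opening/closing parentheses and commas as punctuation.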
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    // Currently >= 2-D scalable vectors are not supported (they cannot be
    // flattened or indexed with SSA values here).
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Widen oddly sized integers to the next legal power-of-two width and
      // work on signless vectors for the bitcasts/extensions.
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType,
                                        value);
      if (value.getType() != signlessTargetVectorType) {
        if (width == 1 || intTy.isUnsigned())
          value = arith::ExtUIOp::create(rewriter, loc,
                                         signlessTargetVectorType, value);
        else
          value = arith::ExtSIOp::create(rewriter, loc,
                                         signlessTargetVectorType, value);
      }
      value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value);
      vectorType = targetVectorType;
    }
755 auto shape = vectorType.getShape();
756 constexpr int64_t singletonShape[] = {1};
757 if (vectorType.getRank() == 0)
758 shape = singletonShape;
760 if (vectorType.getRank() != 1) {
763 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
764 std::multiplies<int64_t>());
765 auto flatVectorType =
767 value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
770 vector::PrintOp firstClose;
772 for (
unsigned d = 0; d < shape.size(); d++) {
778 if (!scalableDimensions.empty() && scalableDimensions[d]) {
779 auto vscale = vector::VectorScaleOp::create(rewriter, loc,
781 upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale);
783 auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step);
786 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
788 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
790 rewriter, loc, vector::PrintPunctuation::Close);
794 auto loopIdx = loop.getInductionVar();
795 loopIndices.push_back(loopIdx);
799 auto notLastIndex = arith::CmpIOp::create(
800 rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
801 scf::IfOp::create(rewriter, loc, notLastIndex,
803 vector::PrintOp::create(
804 builder, loc, vector::PrintPunctuation::Comma);
805 scf::YieldOp::create(builder, loc);
814 auto currentStride = 1;
815 for (
int d = shape.size() - 1; d >= 0; d--) {
818 auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]);
820 flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index);
823 currentStride *= shape[d];
827 auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex);
828 vector::PrintOp::create(rewriter, loc, element,
829 vector::PrintPunctuation::NoPunctuation);
832 vector::PrintOp::create(rewriter, loc,
printOp.getPunctuation());
  /// Return the signless variant of the given integer type.
  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  }
};
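// TransferOpConversion (below) is the core of the progressive lowering: for a
// labeled N-D transfer op it unpacks the major vector dimension, wraps the
// resulting (N-1)-D transfer in an scf.for loop over that dimension, and
// relies on the pass label to recurse until the target rank is reached.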
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    // Reuse the load indices of the previous iteration's mask load.
    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      // If there is no previous load op, the indices stay empty.
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());
        break;
      }
    }

    // In case of broadcast: the unpacked dim does not exist in the mask.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }
  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return rewriter.notifyMatchFailure(
          xferOp, "kPassLabel is present (progressive lowering in progress)");

    // Find and cast data buffer. How the buffer is found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return rewriter.notifyMatchFailure(xferOp,
                                         "Failed to unpack one vector dim.");

    auto castedDataBuffer =
        vector::TypeCastOp::create(locB, *castedDataType, dataBuffer);

    // If the xferOp has a mask: find and cast the mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask if the to-be-unpacked transfer
        // op dimension is a broadcast, or the mask is 1-D (the unpacked dim
        // does not exist in the mask).
        castedMaskBuffer = maskBuffer;
      } else {
        // The mask buffer can be unpacked if the data buffer was unpacked.
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            vector::TypeCastOp::create(locB, castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = arith::ConstantIndexOp::create(locB, 0);
    auto ub = arith::ConstantIndexOp::create(
        locB, castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = arith::ConstantIndexOp::create(locB, 1);
    // TransferWriteOps on tensors return the updated tensor; pass it through
    // the loop as an iter_arg.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate a for loop that unpacks one dimension per iteration.
    auto result = scf::ForOp::create(
        rewriter, xferOp.getLoc(), lb, ub, step,
        loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              /*resultTypes=*/TypeRange(loopState),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create the new (N-1)-D transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If the old transfer op has a mask, set the mask on the new
                // op by loading the unpacked slice from the casted buffer.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer,
                                                     loadIndices);
                  newXfer.getMaskMutable().assign(mask);
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};
/// Retrieve the mask dimension sizes from a vector.create_mask or
/// vector.constant_mask op, scaling scalable dims by vscale.
template <typename VscaleConstantBuilder>
static FailureOr<SmallVector<OpFoldResult>>
getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
  if (!mask)
    return SmallVector<OpFoldResult>{};
  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
      return OpFoldResult(dimSize);
    });
  }
  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
    int dimIdx = 0;
    VectorType maskType = constantMask.getVectorType();
    auto indexType = IndexType::get(mask.getContext());
    return llvm::map_to_vector(
        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
          // A scalable dim size is dimSize * vscale.
          if (maskType.getScalableDims()[dimIdx++])
            return OpFoldResult(createVscaleMultiple(dimSize));
          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
        });
  }
  return failure();
}
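// ScalableTransposeTransferWriteConversion (below) lowers a transfer_write of
// a transposed 2-D scalable vector (vector<[N]xMxty>, produced by a
// vector.transpose) into an scf.for loop that writes one M-element slice per
// iteration, so the scalable transpose never has to be materialized.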
struct ScalableTransposeTransferWriteConversion
    : VectorToSCFPattern<vector::TransferWriteOp> {
  using VectorToSCFPattern::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkLowerTensors(writeOp, rewriter)))
      return failure();

    VectorType vectorType = writeOp.getVectorType();
    // Note: Comparing the scalable dims to an ArrayRef of length two
    // implicitly checks that the rank is also two.
    ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
    if (scalableFlags != ArrayRef<bool>{true, false}) {
      return rewriter.notifyMatchFailure(
          writeOp, "expected vector of the form vector<[N]xMxty>");
    }

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          writeOp, "non-identity permutations are unsupported (lower first)");
    }

    // Only the leading dimension is lowered to a loop, so only its in_bounds
    // attribute matters here.
    if (!writeOp.isDimInBounds(0)) {
      return rewriter.notifyMatchFailure(
          writeOp, "out-of-bounds dims are unsupported (use masking)");
    }

    Value vector = writeOp.getVector();
    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
    if (!transposeOp ||
        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
    }

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
    if (failed(maskDims)) {
      return rewriter.notifyMatchFailure(writeOp,
                                         "failed to resolve mask dims");
    }
    int64_t fixedDimSize = vectorType.getDimSize(1);
    auto fixedDimOffsets = llvm::seq(fixedDimSize);

    // Extract all slices from the source of the transpose.
    auto transposeSource = transposeOp.getVector();
    SmallVector<Value> transposeSourceSlices =
        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
          return vector::ExtractOp::create(rewriter, loc, transposeSource, idx);
        });

    // Loop bounds and step.
    auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto ub =
        maskDims->empty()
            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
            : vector::getAsValues(rewriter, loc, maskDims->front()).front();
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);

    // Generate a mask for each written slice, if the write is masked.
    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
    Value sliceMask = nullptr;
    if (!maskDims->empty()) {
      sliceMask = vector::CreateMaskOp::create(
          rewriter, loc, sliceType.clone(rewriter.getI1Type()),
          ArrayRef<OpFoldResult>(*maskDims).drop_front());
    }

    Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
    auto result = scf::ForOp::create(
        rewriter, loc, lb, ub, step, initLoopArgs,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
          // Indices for the new transfer op.
          SmallVector<Value, 8> xferIndices;
          getXferIndices(b, writeOp, iv, xferIndices);

          // Extract a transposed slice from the source vector.
          SmallVector<Value> transposeElements =
              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
                return vector::ExtractOp::create(
                    b, loc, transposeSourceSlices[idx], iv);
              });
          auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType,
                                                         transposeElements);

          // Create a write for the slice.
          Value dest =
              loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
          auto newWriteOp = vector::TransferWriteOp::create(
              b, loc, sliceVec, dest, xferIndices,
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
          if (sliceMask)
            newWriteOp.getMaskMutable().assign(sliceMask);

          // Yield the updated tensor, if any.
          scf::YieldOp::create(b, loc,
                               loopIterArgs.empty() ? ValueRange{}
                                                    : newWriteOp.getResult());
        });

    if (isTensorOp(writeOp))
      rewriter.replaceOp(writeOp, result);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};
/// If the original transfer op has a mask, assign the corresponding mask (or
/// mask slice for unrolled position `i`) to the new transfer op.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // The unpacked dim is a broadcast: the mask is unchanged.
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp);

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }

  // Otherwise the mask is 1-D and the dim is not a broadcast: no mask needed.
}
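// UnrollTransferReadConversion (below) is used when full unrolling is
// requested (options.unroll): instead of generating an scf.for loop, the major
// vector dimension is unrolled into a sequence of (N-1)-D transfer_read ops
// whose results are combined with vector.insert.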
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Get or build the vector into which the newly created TransferReadOp
  /// results are inserted.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(),
                                       xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }
    return vector::InsertOp();
  }

  /// If the single user is a vector::InsertOp, return its indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less or equal to target rank");
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled at compile time");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    auto vecType = dyn_cast<VectorType>(vec.getType());

    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);

    int64_t dimSize = xferVecType.getShape()[0];

    // Generate a fully unrolled sequence of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(rewriter.getIndexAttr(i));

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = vector::TransferReadOp::create(
                b, loc, newXferVecType, xferOp.getBase(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);

            Value valToInser = newXferOp.getResult();
            if (newXferVecType.getRank() == 0) {
              // vector.insert does not accept rank-0 as the non-indexed
              // argument; extract the scalar before inserting.
              valToInser = vector::ExtractOp::create(b, loc, valToInser,
                                                     SmallVector<int64_t>());
            }
            return vector::InsertOp::create(b, loc, valToInser, vec,
                                            insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Keep the original (unmodified) vector.
            return vec;
          });
    }
    // ... (replace the original insert op / transfer op with the new vector)
    return success();
  }
};
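// UnrollTransferWriteConversion (below) is the write-side counterpart: the
// major dimension is unrolled into vector.extract + (N-1)-D transfer_write
// pairs, threading the updated tensor through when operating on tensors.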
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which the newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    if (inputVectorTy.getScalableDims()[0]) {
      // Cannot unroll a scalable dimension at compile time.
      return failure();
    }

    int64_t dimSize = inputVectorTy.getShape()[0];
    Value source = xferOp.getBase(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

    // Generate a fully unrolled sequence of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(b.getI64IntegerAttr(i));

            auto extracted =
                vector::ExtractOp::create(b, loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              // When target-rank=0, unrolling would cause the vector.extract
              // to produce a scalar; wrap it back into a 0-D vector.
              xferVec = vector::BroadcastOp::create(
                  b, loc, VectorType::get({}, extracted.getType()), extracted);
            } else {
              xferVec = extracted;
            }
            auto newXferOp = vector::TransferWriteOp::create(
                b, loc, sourceType, xferVec, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                inBoundsAttr);

            maybeAssignMask(b, xferOp, newXferOp, i);

            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);

    return success();
  }
};
/// Compute the memref indices for a 1-D transfer at loop position `iv` and
/// return the accessed memref dimension (or std::nullopt for a broadcast).
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");

  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}
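// The lowering_1_d patterns (below) handle 1-D transfers that cannot be
// lowered to LLVM directly (e.g. non-minor-identity permutation maps or
// non-unit strides): they are rewritten into an scf.for loop of scalar
// memref.load/store combined with vector.insert/extract.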
/// Codegen strategy for TransferOp1dConversion, depending on the operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of an out-of-bounds access, leave `vec` as is (it was
    // initialized with the padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices);
          return vector::InsertOp::create(b, loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    scf::YieldOp::create(b, loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
    Location loc = xferOp.getLoc();
    return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(),
                                       xferOp.getPadding());
  }
};
/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of an out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv);
          memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices);
        });
    scf::YieldOp::create(b, loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};
/// Lower a 1-D vector transfer op to an scf.for loop of scalar loads/stores.
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure(); // Handled by the ConvertVectorToLLVM lowering.

    // Loop bounds, step and initial state.
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value ub =
        arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale =
          vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
      ub = arith::MulIOp::create(rewriter, loc, ub, vscale);
    }
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate the for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}
struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
    mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsGreedily(getOperation(),
                                std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
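// Usage sketch (illustrative, not part of this file): the patterns above can
// be driven from any pass; `MyLoweringPass` and the option values below are
// made-up names for the example.
//
//   void MyLoweringPass::runOnOperation() {
//     VectorTransferToSCFOptions options;
//     options.targetRank = 1;       // keep lowering until transfers are 1-D
//     options.lowerTensors = true;  // also lower transfers on tensors
//     RewritePatternSet patterns(&getContext());
//     populateVectorToSCFConversionPatterns(patterns, options);
//     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
//   }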