15 #include <type_traits>
35 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
36 #include "mlir/Conversion/Passes.h.inc"
40 using vector::TransferReadOp;
41 using vector::TransferWriteOp;
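/// Attribute used to mark transfer ops that have been prepared for progressive
/// lowering and still need to be lowered further.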
46 static const char kPassLabel[] = "__vector_to_scf_lowering__";
49 static bool isTensorOp(VectorTransferOpInterface xferOp) {
50 if (isa<RankedTensorType>(xferOp.getShapedType())) {
51 if (isa<vector::TransferWriteOp>(xferOp)) {
53 assert(xferOp->getNumResults() > 0);
62 template <typename OpTy>
68 LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
70 if (isTensorOp(xferOp) && !options.lowerTensors) {
72 xferOp, "lowering tensor transfers is disabled");
83 template <typename OpTy>
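/// Return the position of the dimension that is unpacked when lowering this
/// transfer op by one vector dimension, or std::nullopt if the major vector
/// dimension is a broadcast.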
84 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
86 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
87 auto map = xferOp.getPermutationMap();
88 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
89 return expr.getPosition();
91 assert(xferOp.isBroadcastDim(0) &&
92 "Expected AffineDimExpr or AffineConstantExpr");
99 template <typename OpTy>
102 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
103 auto map = xferOp.getPermutationMap();
104 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
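/// Compute the indices for the new, lower-rank transfer op: the original
/// indices, with the loop induction variable added to the index of the
/// unpacked dimension (broadcast dimensions keep their indices unchanged).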
114 template <typename OpTy>
117 typename OpTy::Adaptor adaptor(xferOp);
119 auto dim = unpackedDim(xferOp);
120 auto prevIndices = adaptor.getIndices();
121 indices.append(prevIndices.begin(), prevIndices.end());
124 bool isBroadcast = !dim.has_value();
127 bindDims(xferOp.getContext(), d0, d1);
128 Value offset = adaptor.getIndices()[*dim];
137 assert(value && "Expected non-empty value");
138 b.create<scf::YieldOp>(loc, value);
140 b.create<scf::YieldOp>(loc);
150 template <typename OpTy>
152 if (!xferOp.getMask())
154 if (xferOp.getMaskType().getRank() != 1)
156 if (xferOp.isBroadcastDim(0))
160 return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
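/// Emit an scf.if that runs `inBoundsCase` under a bounds check on the
/// unpacked dimension (combined with the mask element, if any), and
/// `outOfBoundsCase`, when provided, in the else branch.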
187 template <typename OpTy>
188 static Value generateInBoundsCheck(
193 bool hasRetVal = !resultTypes.empty();
197 bool isBroadcast = !dim;
200 if (!xferOp.isDimInBounds(0) && !isBroadcast) {
204 bindDims(xferOp.getContext(), d0, d1);
205 Value base = xferOp.getIndices()[*dim];
208 cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
213 if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
215 cond = lb.create<arith::AndIOp>(cond, maskCond);
222 auto check = lb.create<scf::IfOp>(
226 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
230 if (outOfBoundsCase) {
231 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
233 b.create<scf::YieldOp>(loc);
237 return hasRetVal ? check.getResult(0) : Value();
241 return inBoundsCase(b, loc);
246 template <typename OpTy>
247 static void generateInBoundsCheck(
251 generateInBoundsCheck(
255 inBoundsCase(b, loc);
261 outOfBoundsCase(b, loc);
267 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
275 template <typename OpTy>
276 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
277 unsigned targetRank) {
278 if (newXferOp.getVectorType().getRank() > targetRank)
285 struct BufferAllocs {
294 assert(scope && "Expected op to be inside automatic allocation scope");
299 template <typename OpTy>
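/// Allocate temporary buffers (memref.alloca) for the transferred vector and,
/// if present, for the mask, at the top of the closest enclosing automatic
/// allocation scope.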
300 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
305 "AutomaticAllocationScope with >1 regions");
310 result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
312 if (xferOp.getMask()) {
314 auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
316 b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
317 result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
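/// Move the major dimension of the vector element type into the memref shape,
/// e.g. memref<?xvector<4x8xf32>> becomes memref<?x4xvector<8xf32>>; fails if
/// that dimension is scalable.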
327 static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
328 auto vectorType = dyn_cast<VectorType>(type.getElementType());
331 if (vectorType.getScalableDims().front())
333 auto memrefShape = type.getShape();
335 newMemrefShape.append(memrefShape.begin(), memrefShape.end());
336 newMemrefShape.push_back(vectorType.getDimSize(0));
343 template <typename OpTy>
344 static Value getMaskBuffer(OpTy xferOp) {
345 assert(xferOp.getMask() && "Expected that transfer op has mask");
346 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
347 assert(loadOp && "Expected transfer op mask produced by LoadOp");
348 return loadOp.getMemRef();
352 template <typename OpTy>
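/// Codegen strategy for TransferReadOp: each loop iteration performs a
/// lower-rank transfer_read and stores the result into the temporary buffer;
/// out-of-bounds iterations store the padding value instead.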
357 struct Strategy<TransferReadOp> {
360 static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
361 assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
362 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
363 assert(storeOp && "Expected TransferReadOp result used by StoreOp");
375 return getStoreOp(xferOp).getMemRef();
379 static void getBufferIndices(TransferReadOp xferOp,
381 auto storeOp = getStoreOp(xferOp);
382 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
383 indices.append(prevIndices.begin(), prevIndices.end());
413 static TransferReadOp rewriteOp(OpBuilder &b,
415 TransferReadOp xferOp, Value buffer, Value iv,
418 getBufferIndices(xferOp, storeIndices);
419 storeIndices.push_back(iv);
425 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
426 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
427 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
428 auto newXferOp = b.create<vector::TransferReadOp>(
429 loc, vecType, xferOp.getSource(), xferIndices,
431 xferOp.getPadding(), Value(), inBoundsAttr);
433 maybeApplyPassLabel(b, newXferOp, options.targetRank);
435 b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
441 static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
445 getBufferIndices(xferOp, storeIndices);
446 storeIndices.push_back(iv);
449 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
450 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
451 auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
452 b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
460 rewriter.eraseOp(getStoreOp(xferOp));
465 static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
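/// Codegen strategy for TransferWriteOp: each loop iteration loads a
/// lower-rank vector from the temporary buffer and writes it back with a
/// lower-rank transfer_write.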
470 struct Strategy<TransferWriteOp> {
479 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
480 assert(loadOp && "Expected transfer op vector produced by LoadOp");
481 return loadOp.getMemRef();
485 static void getBufferIndices(TransferWriteOp xferOp,
487 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
488 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
489 indices.append(prevIndices.begin(), prevIndices.end());
501 static TransferWriteOp rewriteOp(OpBuilder &b,
503 TransferWriteOp xferOp, Value buffer,
506 getBufferIndices(xferOp, loadIndices);
507 loadIndices.push_back(iv);
513 auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
514 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
515 auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
516 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
517 auto newXferOp = b.create<vector::TransferWriteOp>(
518 loc, type, vec, source, xferIndices,
522 maybeApplyPassLabel(b, newXferOp, options.targetRank);
528 static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
531 return isTensorOp(xferOp) ? loopState[0] : Value();
537 if (isTensorOp(xferOp)) {
538 assert(forOp->getNumResults() == 1 && "Expected one for loop result");
539 rewriter.replaceOp(xferOp, forOp->getResult(0));
546 static Value initialLoopState(TransferWriteOp xferOp) {
547 return isTensorOp(xferOp) ? xferOp.getSource() : Value();
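/// Preconditions for the Prepare* patterns: the op must not be labeled yet,
/// its vector rank must exceed the target rank, and leading scalable
/// dimensions, disabled tensor transfers, and element type mismatches bail out.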
551 template <typename OpTy>
552 LogicalResult checkPrepareXferOp(OpTy xferOp,
554 if (xferOp->hasAttr(kPassLabel))
556 if (xferOp.getVectorType().getRank() <= options.targetRank)
560 if (xferOp.getVectorType().getScalableDims().front())
562 if (isTensorOp(xferOp) && !options.lowerTensors)
565 if (xferOp.getVectorType().getElementType() !=
566 xferOp.getShapedType().getElementType())
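/// Prepare a transfer_read for progressive lowering: the read is cloned and
/// labeled, and its result is staged through a temporary buffer so that
/// TransferOpConversion can peel it one dimension at a time.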
594 struct PrepareTransferReadConversion
595 : public VectorToSCFPattern<TransferReadOp> {
596 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
598 LogicalResult matchAndRewrite(TransferReadOp xferOp,
600 if (checkPrepareXferOp(xferOp, options).failed())
603 auto buffers = allocBuffers(rewriter, xferOp);
604 auto *newXfer = rewriter.clone(*xferOp.getOperation());
606 if (xferOp.getMask()) {
607 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
612 rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
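/// Prepare a transfer_write for progressive lowering: the vector to be written
/// is staged through a temporary buffer (stored, then re-loaded), the op is
/// labeled, and the mask (if any) is routed through its own buffer.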
643 struct PrepareTransferWriteConversion
644 : public VectorToSCFPattern<TransferWriteOp> {
645 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
647 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
649 if (checkPrepareXferOp(xferOp, options).failed())
653 auto buffers = allocBuffers(rewriter, xferOp);
654 rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
656 auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
658 xferOp.getVectorMutable().assign(loadedVec);
659 xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
662 if (xferOp.getMask()) {
664 xferOp.getMaskMutable().assign(buffers.maskBuffer);
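/// Lower vector.print to a loop nest of element extracts and scalar prints.
/// Small integer elements are extended to a legal width first and
/// multi-dimensional vectors are flattened; scalable vectors of rank > 1 are
/// rejected.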
699 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
700 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
701 LogicalResult matchAndRewrite(vector::PrintOp printOp,
706 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
716 if (vectorType.getRank() > 1 && vectorType.isScalable())
720 auto value = printOp.getSource();
722 if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
726 auto width = intTy.getWidth();
727 auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
729 intTy.getSignedness());
731 auto signlessSourceVectorType =
732 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
733 auto signlessTargetVectorType =
734 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
735 auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
736 value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
738 if (value.getType() != signlessTargetVectorType) {
739 if (width == 1 || intTy.isUnsigned())
740 value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
743 value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
746 value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
747 vectorType = targetVectorType;
750 auto scalableDimensions = vectorType.getScalableDims();
751 auto shape = vectorType.getShape();
752 constexpr int64_t singletonShape[] = {1};
753 if (vectorType.getRank() == 0)
754 shape = singletonShape;
756 if (vectorType.getRank() != 1) {
760 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
761 std::multiplies<int64_t>());
762 auto flatVectorType =
764 value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
767 vector::PrintOp firstClose;
769 for (unsigned d = 0; d < shape.size(); d++) {
771 Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
772 Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
773 Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
774 if (!scalableDimensions.empty() && scalableDimensions[d]) {
775 auto vscale = rewriter.create<vector::VectorScaleOp>(
777 upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
779 auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
782 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
784 rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
786 loc, vector::PrintPunctuation::Close);
790 auto loopIdx = loop.getInductionVar();
791 loopIndices.push_back(loopIdx);
795 auto notLastIndex = rewriter.create<arith::CmpIOp>(
796 loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
797 rewriter.create<scf::IfOp>(loc, notLastIndex,
799 builder.create<vector::PrintOp>(
800 loc, vector::PrintPunctuation::Comma);
801 builder.create<scf::YieldOp>(loc);
810 auto currentStride = 1;
811 for (int d = shape.size() - 1; d >= 0; d--) {
812 auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
813 auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
815 flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
818 currentStride *= shape[d];
823 rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
824 rewriter.create<vector::PrintOp>(loc, element,
825 vector::PrintPunctuation::NoPunctuation);
828 rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
833 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
835 IntegerType::Signless);
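/// Progressive lowering of a labeled transfer op: one dimension of the
/// temporary buffer is unpacked and an scf.for loop is emitted whose body
/// performs the (N-1)-d transfer, guarded by the in-bounds/mask check.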
868 template <typename OpTy>
869 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
870 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
875 this->setHasBoundedRewriteRecursion();
878 static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
881 assert(xferOp.getMask() && "Expected transfer op to have mask");
887 Value maskBuffer = getMaskBuffer(xferOp);
890 if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
892 loadIndices.append(prevIndices.begin(), prevIndices.end());
899 if (!xferOp.isBroadcastDim(0))
900 loadIndices.push_back(iv);
903 LogicalResult matchAndRewrite(OpTy xferOp,
905 if (!xferOp->hasAttr(kPassLabel))
911 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
912 FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
913 if (failed(castedDataType))
916 auto castedDataBuffer =
917 locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
920 Value castedMaskBuffer;
921 if (xferOp.getMask()) {
922 Value maskBuffer = getMaskBuffer(xferOp);
923 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
929 castedMaskBuffer = maskBuffer;
933 auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
934 MemRefType castedMaskType = *unpackOneDim(maskBufferType);
936 locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
941 auto lb = locB.create<arith::ConstantIndexOp>(0);
942 auto ub = locB.create<arith::ConstantIndexOp>(
943 castedDataType->getDimSize(castedDataType->getRank() - 1));
944 auto step = locB.create<arith::ConstantIndexOp>(1);
947 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
950 auto result = locB.create<scf::ForOp>(
953 Type stateType = loopState.empty() ? Type() : loopState[0].getType();
955 auto result = generateInBoundsCheck(
956 b, xferOp, iv, unpackedDim(xferOp),
961 OpTy newXfer = Strategy<OpTy>::rewriteOp(
962 b, this->options, xferOp, castedDataBuffer, iv, loopState);
968 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
969 xferOp.getMaskType().getRank() > 1)) {
974 getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
976 auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
979 newXfer.getMaskMutable().assign(mask);
983 return loopState.empty() ? Value() : newXfer->getResult(0);
987 return Strategy<OpTy>::handleOutOfBoundsDim(
988 b, xferOp, castedDataBuffer, iv, loopState);
991 maybeYieldValue(b, loc, !loopState.empty(), result);
994 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
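/// Retrieve the mask dimension sizes as OpFoldResults from a defining
/// vector.create_mask or vector.constant_mask op, scaling scalable dimensions
/// by vscale; fails for any other mask producer.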
1001 template <typename VscaleConstantBuilder>
1002 static FailureOr<SmallVector<OpFoldResult>>
1003 getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1006 if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1007 return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1008 return OpFoldResult(dimSize);
1011 if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1013 VectorType maskType = constantMask.getVectorType();
1015 return llvm::map_to_vector(
1016 constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1018 if (maskType.getScalableDims()[dimIdx++])
1019 return OpFoldResult(createVscaleMultiple(dimSize));
1020 return OpFoldResult(IntegerAttr::get(indexType, dimSize));
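/// Lower a transfer_write of a transposed scalable vector by looping over the
/// scalable dimension and writing one fixed-size slice per iteration,
/// assembling each slice with vector.from_elements.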
1063 struct ScalableTransposeTransferWriteConversion
1064 : VectorToSCFPattern<vector::TransferWriteOp> {
1065 using VectorToSCFPattern::VectorToSCFPattern;
1067 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
1069 if (failed(checkLowerTensors(writeOp, rewriter)))
1072 VectorType vectorType = writeOp.getVectorType();
1079 writeOp, "expected vector of the form vector<[N]xMxty>");
1082 auto permutationMap = writeOp.getPermutationMap();
1083 if (!permutationMap.isIdentity()) {
1085 writeOp, "non-identity permutations are unsupported (lower first)");
1091 if (!writeOp.isDimInBounds(0)) {
1093 writeOp, "out-of-bounds dims are unsupported (use masking)");
1096 Value vector = writeOp.getVector();
1097 auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
1103 auto loc = writeOp.getLoc();
1104 auto createVscaleMultiple =
1107 auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
1108 if (failed(maskDims)) {
1110 "failed to resolve mask dims");
1113 int64_t fixedDimSize = vectorType.getDimSize(1);
1114 auto fixedDimOffsets = llvm::seq(fixedDimSize);
1117 auto transposeSource = transposeOp.getVector();
1119 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1120 return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
1124 auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1127 ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
1129 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1133 Value sliceMask = nullptr;
1134 if (!maskDims->empty()) {
1135 sliceMask = rewriter.create<vector::CreateMaskOp>(
1136 loc, sliceType.clone(rewriter.getI1Type()),
1140 Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
1142 auto result = rewriter.create<scf::ForOp>(
1143 loc, lb, ub, step, initLoopArgs,
1151 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1152 return b.create<vector::ExtractOp>(
1153 loc, transposeSourceSlices[idx], iv);
1155 auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
1160 loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
1161 auto newWriteOp = b.create<vector::TransferWriteOp>(
1162 loc, sliceVec, dest, xferIndices,
1165 newWriteOp.getMaskMutable().assign(sliceMask);
1168 b.create<scf::YieldOp>(loc, loopIterArgs.empty()
1170 : newWriteOp.getResult());
1173 if (isTensorOp(writeOp))
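/// Propagate the mask to an unrolled transfer op: the original mask is reused
/// when the major dimension is a broadcast; otherwise the i-th slice of a
/// higher-rank mask is extracted.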
1188 template <typename OpTy>
1189 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1191 if (!xferOp.getMask())
1194 if (xferOp.isBroadcastDim(0)) {
1197 newXferOp.getMaskMutable().assign(xferOp.getMask());
1201 if (xferOp.getMaskType().getRank() > 1) {
1208 auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1209 newXferOp.getMaskMutable().assign(newMask);
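/// Lower a transfer_read by fully unrolling its major (non-scalable)
/// dimension: each iteration issues a lower-rank read and inserts the result
/// into the accumulated vector.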
1245 struct UnrollTransferReadConversion
1246 : public VectorToSCFPattern<TransferReadOp> {
1247 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1252 setHasBoundedRewriteRecursion();
1258 TransferReadOp xferOp) const {
1259 if (auto insertOp = getInsertOp(xferOp))
1260 return insertOp.getDest();
1262 return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1263 xferOp.getPadding());
1268 vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1269 if (xferOp->hasOneUse()) {
1271 if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1275 return vector::InsertOp();
1280 void getInsertionIndices(TransferReadOp xferOp,
1282 if (auto insertOp = getInsertOp(xferOp)) {
1283 auto pos = insertOp.getMixedPosition();
1284 indices.append(pos.begin(), pos.end());
1290 LogicalResult matchAndRewrite(TransferReadOp xferOp,
1292 if (xferOp.getVectorType().getRank() <= options.targetRank)
1294 xferOp, "vector rank is less or equal to target rank");
1295 if (failed(checkLowerTensors(xferOp, rewriter)))
1298 if (xferOp.getVectorType().getElementType() !=
1299 xferOp.getShapedType().getElementType())
1301 xferOp, "not yet supported: element type mismatch");
1302 auto xferVecType = xferOp.getVectorType();
1303 if (xferVecType.getScalableDims()[0]) {
1306 xferOp, "scalable dimensions cannot be unrolled");
1309 auto insertOp = getInsertOp(xferOp);
1310 auto vec = buildResultVector(rewriter, xferOp);
1311 auto vecType = dyn_cast<VectorType>(vec.getType());
1315 int64_t dimSize = xferVecType.getShape()[0];
1319 for (int64_t i = 0; i < dimSize; ++i) {
1320 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1322 vec = generateInBoundsCheck(
1323 rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1332 getInsertionIndices(xferOp, insertionIndices);
1335 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1336 auto newXferOp = b.create<vector::TransferReadOp>(
1337 loc, newXferVecType, xferOp.getSource(), xferIndices,
1339 xferOp.getPadding(), Value(), inBoundsAttr);
1340 maybeAssignMask(b, xferOp, newXferOp, i);
1341 return b.create<vector::InsertOp>(loc, newXferOp, vec,
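/// Lower a transfer_write by fully unrolling its major (non-scalable)
/// dimension: each iteration extracts a lower-rank slice and writes it with
/// its own transfer_write.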
1389 struct UnrollTransferWriteConversion
1390 : public VectorToSCFPattern<TransferWriteOp> {
1391 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1396 setHasBoundedRewriteRecursion();
1400 Value getDataVector(TransferWriteOp xferOp) const {
1401 if (auto extractOp = getExtractOp(xferOp))
1402 return extractOp.getVector();
1403 return xferOp.getVector();
1407 vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1408 if (auto *op = xferOp.getVector().getDefiningOp())
1409 return dyn_cast<vector::ExtractOp>(op);
1410 return vector::ExtractOp();
1415 void getExtractionIndices(TransferWriteOp xferOp,
1417 if (auto extractOp = getExtractOp(xferOp)) {
1418 auto pos = extractOp.getMixedPosition();
1419 indices.append(pos.begin(), pos.end());
1425 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1427 VectorType inputVectorTy = xferOp.getVectorType();
1429 if (inputVectorTy.getRank() <= options.targetRank)
1432 if (failed(checkLowerTensors(xferOp, rewriter)))
1435 if (inputVectorTy.getElementType() !=
1436 xferOp.getShapedType().getElementType())
1439 auto vec = getDataVector(xferOp);
1440 if (inputVectorTy.getScalableDims()[0]) {
1445 int64_t dimSize = inputVectorTy.getShape()[0];
1446 Value source = xferOp.getSource();
1447 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1451 for (int64_t i = 0; i < dimSize; ++i) {
1452 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1454 auto updatedSource = generateInBoundsCheck(
1455 rewriter, xferOp, iv, unpackedDim(xferOp),
1465 getExtractionIndices(xferOp, extractionIndices);
1469 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1470 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1472 if (inputVectorTy.getRank() == 1) {
1476 xferVec = b.create<vector::BroadcastOp>(
1479 xferVec = extracted;
1481 auto newXferOp = b.create<vector::TransferWriteOp>(
1482 loc, sourceType, xferVec, source, xferIndices,
1486 maybeAssignMask(b, xferOp, newXferOp, i);
1488 return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1492 return isTensorOp(xferOp) ? source : Value();
1495 if (isTensorOp(xferOp))
1496 source = updatedSource;
1499 if (isTensorOp(xferOp))
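/// Compute the memref indices for the 1-D lowering: the induction variable is
/// added to the index of the transferred dimension; returns that dimension,
/// or std::nullopt for a broadcast.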
1516 template <typename OpTy>
1517 static std::optional<int64_t>
1520 auto indices = xferOp.getIndices();
1521 auto map = xferOp.getPermutationMap();
1522 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1524 memrefIndices.append(indices.begin(), indices.end());
1525 assert(map.getNumResults() == 1 &&
1526 "Expected 1 permutation map result for 1D transfer");
1527 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1529 auto dim = expr.getPosition();
1531 bindDims(xferOp.getContext(), d0, d1);
1532 Value offset = memrefIndices[dim];
1533 memrefIndices[dim] =
1538 assert(xferOp.isBroadcastDim(0) &&
1539 "Expected AffineDimExpr or AffineConstantExpr");
1540 return std::nullopt;
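/// 1-D lowering strategies: transfer_read becomes a loop of memref.load +
/// vector.insertelement (falling back to the padding value when out of
/// bounds), and transfer_write becomes a loop of vector.extractelement +
/// memref.store.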
1545 template <typename OpTy>
1550 struct Strategy1d<TransferReadOp> {
1552 TransferReadOp xferOp, Value iv,
1555 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1556 auto vec = loopState[0];
1560 auto nextVec = generateInBoundsCheck(
1561 b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1565 b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
1566 return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1570 b.create<scf::YieldOp>(loc, nextVec);
1573 static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
1576 return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1577 xferOp.getPadding());
1583 struct Strategy1d<TransferWriteOp> {
1585 TransferWriteOp xferOp, Value iv,
1588 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1591 generateInBoundsCheck(
1595 b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1596 b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
1598 b.create<scf::YieldOp>(loc);
1601 static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
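/// Lower a rank-1 transfer op into an scf.for loop over the vector elements,
/// using the per-element Strategy1d above.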
1637 template <typename OpTy>
1638 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1639 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1641 LogicalResult matchAndRewrite(OpTy xferOp,
1644 if (xferOp.getTransferRank() == 0)
1646 auto map = xferOp.getPermutationMap();
1647 auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1651 if (xferOp.getVectorType().getRank() != 1)
1658 auto vecType = xferOp.getVectorType();
1659 auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1661 rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1662 if (vecType.isScalable()) {
1665 ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
1667 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1668 auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1674 Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
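/// Pattern registration: depending on options.unroll, either the fully
/// unrolled or the progressive (buffer + loop) lowerings are added; with
/// targetRank == 1 the scalar 1-D lowerings are added as well, plus the
/// vector.print decomposition.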
1687 patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1688 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1691 patterns.add<lowering_n_d::PrepareTransferReadConversion,
1692 lowering_n_d::PrepareTransferWriteConversion,
1693 lowering_n_d::TransferOpConversion<TransferReadOp>,
1694 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1698 patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
1701 if (options.targetRank == 1) {
1702 patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1703 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1706 patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
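/// The conversion pass: copies the constructor options into the pass options,
/// first simplifies transfer permutation maps, then greedily applies the
/// vector-to-SCF conversion patterns registered above.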
1712 struct ConvertVectorToSCFPass
1713 : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1714 ConvertVectorToSCFPass() = default;
1716 this->fullUnroll = options.unroll;
1717 this->targetRank = options.targetRank;
1718 this->lowerTensors = options.lowerTensors;
1719 this->lowerScalable = options.lowerScalable;
1722 void runOnOperation() override {
1725 options.targetRank = targetRank;
1726 options.lowerTensors = lowerTensors;
1727 options.lowerScalable = lowerScalable;
1732 lowerTransferPatterns);
1734 std::move(lowerTransferPatterns));
1744 std::unique_ptr<Pass>
1746 return std::make_unique<ConvertVectorToSCFPass>(options);
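/// A minimal usage sketch (not part of the upstream file): a downstream pass
/// could apply the same patterns directly. The helper name and option values
/// below are illustrative assumptions, and the snippet relies on the includes
/// and using-declarations already present in this file.
static void lowerVectorTransfersToSCFExample(Operation *root) {
  // `root` (and its regions) must be isolated from above for greedy rewriting.
  RewritePatternSet patterns(root->getContext());
  VectorTransferToSCFOptions options;
  options.targetRank = 1;       // keep lowering until rank-1 transfers remain
  options.lowerTensors = true;  // also lower transfers on tensor operands
  populateVectorToSCFConversionPatterns(patterns, options);
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}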