15 #include <type_traits>
34 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
35 #include "mlir/Conversion/Passes.h.inc"
39 using vector::TransferReadOp;
40 using vector::TransferWriteOp;
45 static const char kPassLabel[] =
"__vector_to_scf_lowering__";
49 template <
typename OpTy>
61 template <
typename OpTy>
62 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
64 assert(xferOp.getTransferRank() > 0 &&
"unexpected 0-d transfer");
65 auto map = xferOp.getPermutationMap();
66 if (
auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
67 return expr.getPosition();
69 assert(xferOp.isBroadcastDim(0) &&
70 "Expected AffineDimExpr or AffineConstantExpr");
77 template <
typename OpTy>
80 assert(xferOp.getTransferRank() > 0 &&
"unexpected 0-d transfer");
81 auto map = xferOp.getPermutationMap();
82 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
92 template <
typename OpTy>
95 typename OpTy::Adaptor adaptor(xferOp);
97 auto dim = unpackedDim(xferOp);
98 auto prevIndices = adaptor.getIndices();
99 indices.append(prevIndices.begin(), prevIndices.end());
102 bool isBroadcast = !dim.has_value();
105 bindDims(xferOp.getContext(), d0, d1);
106 Value offset = adaptor.getIndices()[*dim];
115 assert(value &&
"Expected non-empty value");
116 b.
create<scf::YieldOp>(loc, value);
118 b.
create<scf::YieldOp>(loc);
128 template <
typename OpTy>
130 if (!xferOp.getMask())
132 if (xferOp.getMaskType().getRank() != 1)
134 if (xferOp.isBroadcastDim(0))
138 return b.
create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
165 template <
typename OpTy>
166 static Value generateInBoundsCheck(
171 bool hasRetVal = !resultTypes.empty();
175 bool isBroadcast = !dim;
178 if (!xferOp.isDimInBounds(0) && !isBroadcast) {
182 bindDims(xferOp.getContext(), d0, d1);
183 Value base = xferOp.getIndices()[*dim];
186 cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
191 if (
auto maskCond = generateMaskCheck(b, xferOp, iv)) {
193 cond = lb.create<arith::AndIOp>(cond, maskCond);
200 auto check = lb.create<scf::IfOp>(
204 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
208 if (outOfBoundsCase) {
209 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
211 b.
create<scf::YieldOp>(loc);
215 return hasRetVal ? check.getResult(0) :
Value();
219 return inBoundsCase(b, loc);
224 template <
typename OpTy>
225 static void generateInBoundsCheck(
229 generateInBoundsCheck(
233 inBoundsCase(b, loc);
239 outOfBoundsCase(b, loc);
245 static ArrayAttr dropFirstElem(
OpBuilder &b, ArrayAttr attr) {
253 template <
typename OpTy>
254 static void maybeApplyPassLabel(
OpBuilder &b, OpTy newXferOp,
255 unsigned targetRank) {
256 if (newXferOp.getVectorType().getRank() > targetRank)
261 template <
typename OpTy>
262 static bool isTensorOp(OpTy xferOp) {
263 if (isa<RankedTensorType>(xferOp.getShapedType())) {
264 if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
266 assert(xferOp->getNumResults() > 0);
276 struct BufferAllocs {
285 assert(scope &&
"Expected op to be inside automatic allocation scope");
290 template <
typename OpTy>
291 static BufferAllocs allocBuffers(
OpBuilder &b, OpTy xferOp) {
296 "AutomaticAllocationScope with >1 regions");
301 result.dataBuffer = b.
create<memref::AllocaOp>(loc, bufferType);
303 if (xferOp.getMask()) {
305 auto maskBuffer = b.
create<memref::AllocaOp>(loc, maskType);
307 b.
create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
308 result.maskBuffer = b.
create<memref::LoadOp>(loc, maskBuffer,
ValueRange());
319 auto vectorType = dyn_cast<VectorType>(type.getElementType());
322 if (vectorType.getScalableDims().front())
324 auto memrefShape = type.getShape();
326 newMemrefShape.append(memrefShape.begin(), memrefShape.end());
327 newMemrefShape.push_back(vectorType.getDimSize(0));
334 template <
typename OpTy>
335 static Value getMaskBuffer(OpTy xferOp) {
336 assert(xferOp.getMask() &&
"Expected that transfer op has mask");
337 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
338 assert(loadOp &&
"Expected transfer op mask produced by LoadOp");
339 return loadOp.getMemRef();
343 template <
typename OpTy>
348 struct Strategy<TransferReadOp> {
351 static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
352 assert(xferOp->hasOneUse() &&
"Expected exactly one use of TransferReadOp");
353 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
354 assert(storeOp &&
"Expected TransferReadOp result used by StoreOp");
366 return getStoreOp(xferOp).getMemRef();
370 static void getBufferIndices(TransferReadOp xferOp,
372 auto storeOp = getStoreOp(xferOp);
373 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
374 indices.append(prevIndices.begin(), prevIndices.end());
404 static TransferReadOp rewriteOp(
OpBuilder &b,
406 TransferReadOp xferOp,
Value buffer,
Value iv,
409 getBufferIndices(xferOp, storeIndices);
410 storeIndices.push_back(iv);
416 auto bufferType = dyn_cast<ShapedType>(buffer.
getType());
417 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
418 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
419 auto newXferOp = b.
create<vector::TransferReadOp>(
420 loc, vecType, xferOp.getSource(), xferIndices,
422 xferOp.getPadding(),
Value(), inBoundsAttr);
424 maybeApplyPassLabel(b, newXferOp,
options.targetRank);
426 b.
create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
432 static Value handleOutOfBoundsDim(
OpBuilder &b, TransferReadOp xferOp,
436 getBufferIndices(xferOp, storeIndices);
437 storeIndices.push_back(iv);
440 auto bufferType = dyn_cast<ShapedType>(buffer.
getType());
441 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
442 auto vec = b.
create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
443 b.
create<memref::StoreOp>(loc, vec, buffer, storeIndices);
451 rewriter.
eraseOp(getStoreOp(xferOp));
// Initial loop-carried state for the TransferReadOp strategy: always returns
// an empty (null) Value, i.e. no initial state is supplied for reads.
456 static Value initialLoopState(TransferReadOp xferOp) {
return Value(); }
461 struct Strategy<TransferWriteOp> {
470 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
471 assert(loadOp &&
"Expected transfer op vector produced by LoadOp");
472 return loadOp.getMemRef();
476 static void getBufferIndices(TransferWriteOp xferOp,
478 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
479 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
480 indices.append(prevIndices.begin(), prevIndices.end());
492 static TransferWriteOp rewriteOp(
OpBuilder &b,
494 TransferWriteOp xferOp,
Value buffer,
497 getBufferIndices(xferOp, loadIndices);
498 loadIndices.push_back(iv);
504 auto vec = b.
create<memref::LoadOp>(loc, buffer, loadIndices);
505 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
506 auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
507 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() :
Type();
508 auto newXferOp = b.
create<vector::TransferWriteOp>(
509 loc, type, vec, source, xferIndices,
513 maybeApplyPassLabel(b, newXferOp,
options.targetRank);
519 static Value handleOutOfBoundsDim(
OpBuilder &b, TransferWriteOp xferOp,
522 return isTensorOp(xferOp) ? loopState[0] :
Value();
528 if (isTensorOp(xferOp)) {
529 assert(forOp->getNumResults() == 1 &&
"Expected one for loop result");
530 rewriter.
replaceOp(xferOp, forOp->getResult(0));
537 static Value initialLoopState(TransferWriteOp xferOp) {
538 return isTensorOp(xferOp) ? xferOp.getSource() :
Value();
542 template <
typename OpTy>
545 if (xferOp->hasAttr(kPassLabel))
547 if (xferOp.getVectorType().getRank() <=
options.targetRank)
551 if (xferOp.getVectorType().getScalableDims().front())
553 if (isTensorOp(xferOp) && !
options.lowerTensors)
556 if (xferOp.getVectorType().getElementType() !=
557 xferOp.getShapedType().getElementType())
585 struct PrepareTransferReadConversion
586 :
public VectorToSCFPattern<TransferReadOp> {
587 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
594 auto buffers = allocBuffers(rewriter, xferOp);
595 auto *newXfer = rewriter.
clone(*xferOp.getOperation());
597 if (xferOp.getMask()) {
598 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
603 rewriter.
create<memref::StoreOp>(loc, newXfer->getResult(0),
634 struct PrepareTransferWriteConversion
635 :
public VectorToSCFPattern<TransferWriteOp> {
636 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
644 auto buffers = allocBuffers(rewriter, xferOp);
645 rewriter.
create<memref::StoreOp>(loc, xferOp.getVector(),
647 auto loadedVec = rewriter.
create<memref::LoadOp>(loc, buffers.dataBuffer);
649 xferOp.getVectorMutable().assign(loadedVec);
650 xferOp->setAttr(kPassLabel, rewriter.
getUnitAttr());
653 if (xferOp.getMask()) {
655 xferOp.getMaskMutable().assign(buffers.maskBuffer);
690 struct DecomposePrintOpConversion :
public VectorToSCFPattern<vector::PrintOp> {
691 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
697 VectorType vectorType = dyn_cast<VectorType>(
printOp.getPrintType());
707 if (vectorType.getRank() > 1 && vectorType.isScalable())
711 auto value =
printOp.getSource();
713 if (
auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
717 auto width = intTy.getWidth();
718 auto legalWidth = llvm::NextPowerOf2(
std::max(8u, width) - 1);
720 intTy.getSignedness());
722 auto signlessSourceVectorType =
723 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
724 auto signlessTargetVectorType =
725 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
726 auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
727 value = rewriter.
create<vector::BitCastOp>(loc, signlessSourceVectorType,
729 if (value.
getType() != signlessTargetVectorType) {
730 if (width == 1 || intTy.isUnsigned())
731 value = rewriter.
create<arith::ExtUIOp>(loc, signlessTargetVectorType,
734 value = rewriter.
create<arith::ExtSIOp>(loc, signlessTargetVectorType,
737 value = rewriter.
create<vector::BitCastOp>(loc, targetVectorType, value);
738 vectorType = targetVectorType;
741 auto scalableDimensions = vectorType.getScalableDims();
742 auto shape = vectorType.getShape();
743 constexpr int64_t singletonShape[] = {1};
744 if (vectorType.getRank() == 0)
745 shape = singletonShape;
747 if (vectorType.getRank() != 1) {
751 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
752 std::multiplies<int64_t>());
753 auto flatVectorType =
755 value = rewriter.
create<vector::ShapeCastOp>(loc, flatVectorType, value);
758 vector::PrintOp firstClose;
760 for (
unsigned d = 0; d < shape.size(); d++) {
762 Value lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
763 Value upperBound = rewriter.
create<arith::ConstantIndexOp>(loc, shape[d]);
764 Value step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
765 if (!scalableDimensions.empty() && scalableDimensions[d]) {
766 auto vscale = rewriter.
create<vector::VectorScaleOp>(
768 upperBound = rewriter.
create<arith::MulIOp>(loc, upperBound, vscale);
770 auto lastIndex = rewriter.
create<arith::SubIOp>(loc, upperBound, step);
773 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
775 rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
777 loc, vector::PrintPunctuation::Close);
781 auto loopIdx = loop.getInductionVar();
782 loopIndices.push_back(loopIdx);
786 auto notLastIndex = rewriter.
create<arith::CmpIOp>(
787 loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
788 rewriter.
create<scf::IfOp>(loc, notLastIndex,
790 builder.create<vector::PrintOp>(
791 loc, vector::PrintPunctuation::Comma);
792 builder.create<scf::YieldOp>(loc);
801 auto currentStride = 1;
802 for (
int d = shape.size() - 1; d >= 0; d--) {
803 auto stride = rewriter.
create<arith::ConstantIndexOp>(loc, currentStride);
804 auto index = rewriter.
create<arith::MulIOp>(loc, stride, loopIndices[d]);
806 flatIndex = rewriter.
create<arith::AddIOp>(loc, flatIndex, index);
809 currentStride *= shape[d];
814 rewriter.
create<vector::ExtractElementOp>(loc, value, flatIndex);
815 rewriter.
create<vector::PrintOp>(loc, element,
816 vector::PrintPunctuation::NoPunctuation);
819 rewriter.
create<vector::PrintOp>(loc,
printOp.getPunctuation());
824 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
826 IntegerType::Signless);
859 template <
typename OpTy>
860 struct TransferOpConversion :
public VectorToSCFPattern<OpTy> {
861 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
866 this->setHasBoundedRewriteRecursion();
869 static void getMaskBufferLoadIndices(OpTy xferOp,
Value castedMaskBuffer,
872 assert(xferOp.getMask() &&
"Expected transfer op to have mask");
878 Value maskBuffer = getMaskBuffer(xferOp);
881 if (
auto loadOp = dyn_cast<memref::LoadOp>(user)) {
883 loadIndices.append(prevIndices.begin(), prevIndices.end());
890 if (!xferOp.isBroadcastDim(0))
891 loadIndices.push_back(iv);
896 if (!xferOp->hasAttr(kPassLabel))
902 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.
getType());
904 if (
failed(castedDataType))
907 auto castedDataBuffer =
908 locB.
create<vector::TypeCastOp>(*castedDataType, dataBuffer);
911 Value castedMaskBuffer;
912 if (xferOp.getMask()) {
913 Value maskBuffer = getMaskBuffer(xferOp);
914 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
920 castedMaskBuffer = maskBuffer;
924 auto maskBufferType = cast<MemRefType>(maskBuffer.
getType());
925 MemRefType castedMaskType = *unpackOneDim(maskBufferType);
927 locB.
create<vector::TypeCastOp>(castedMaskType, maskBuffer);
932 auto lb = locB.
create<arith::ConstantIndexOp>(0);
933 auto ub = locB.
create<arith::ConstantIndexOp>(
934 castedDataType->getDimSize(castedDataType->getRank() - 1));
935 auto step = locB.
create<arith::ConstantIndexOp>(1);
938 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
941 auto result = locB.
create<scf::ForOp>(
944 Type stateType = loopState.empty() ?
Type() : loopState[0].getType();
946 auto result = generateInBoundsCheck(
947 b, xferOp, iv, unpackedDim(xferOp),
952 OpTy newXfer = Strategy<OpTy>::rewriteOp(
953 b, this->options, xferOp, castedDataBuffer, iv, loopState);
959 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
960 xferOp.getMaskType().getRank() > 1)) {
965 getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
967 auto mask = b.
create<memref::LoadOp>(loc, castedMaskBuffer,
970 newXfer.getMaskMutable().assign(mask);
974 return loopState.empty() ?
Value() : newXfer->getResult(0);
978 return Strategy<OpTy>::handleOutOfBoundsDim(
979 b, xferOp, castedDataBuffer, iv, loopState);
982 maybeYieldValue(b, loc, !loopState.empty(), result);
985 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
996 template <
typename OpTy>
997 static void maybeAssignMask(
OpBuilder &b, OpTy xferOp, OpTy newXferOp,
999 if (!xferOp.getMask())
1002 if (xferOp.isBroadcastDim(0)) {
1005 newXferOp.getMaskMutable().assign(xferOp.getMask());
1009 if (xferOp.getMaskType().getRank() > 1) {
1016 auto newMask = b.
create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1017 newXferOp.getMaskMutable().assign(newMask);
1053 struct UnrollTransferReadConversion
1054 :
public VectorToSCFPattern<TransferReadOp> {
1055 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1060 setHasBoundedRewriteRecursion();
1066 TransferReadOp xferOp)
const {
1067 if (
auto insertOp = getInsertOp(xferOp))
1068 return insertOp.getDest();
1070 return rewriter.
create<vector::SplatOp>(loc, xferOp.getVectorType(),
1071 xferOp.getPadding());
1076 vector::InsertOp getInsertOp(TransferReadOp xferOp)
const {
1077 if (xferOp->hasOneUse()) {
1079 if (
auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1083 return vector::InsertOp();
1088 void getInsertionIndices(TransferReadOp xferOp,
1090 if (
auto insertOp = getInsertOp(xferOp)) {
1091 auto pos = insertOp.getMixedPosition();
1092 indices.append(pos.begin(), pos.end());
1100 if (xferOp.getVectorType().getRank() <=
options.targetRank)
1102 xferOp,
"vector rank is less or equal to target rank");
1103 if (isTensorOp(xferOp) && !
options.lowerTensors)
1105 xferOp,
"transfers operating on tensors are excluded");
1107 if (xferOp.getVectorType().getElementType() !=
1108 xferOp.getShapedType().getElementType())
1110 xferOp,
"not yet supported: element type mismatch");
1111 auto xferVecType = xferOp.getVectorType();
1112 if (xferVecType.getScalableDims()[0]) {
1115 xferOp,
"scalable dimensions cannot be unrolled");
1118 auto insertOp = getInsertOp(xferOp);
1119 auto vec = buildResultVector(rewriter, xferOp);
1120 auto vecType = dyn_cast<VectorType>(vec.getType());
1124 int64_t dimSize = xferVecType.getShape()[0];
1128 for (int64_t i = 0; i < dimSize; ++i) {
1129 Value iv = rewriter.
create<arith::ConstantIndexOp>(loc, i);
1131 vec = generateInBoundsCheck(
1132 rewriter, xferOp, iv, unpackedDim(xferOp),
TypeRange(vecType),
1141 getInsertionIndices(xferOp, insertionIndices);
1144 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1145 auto newXferOp = b.
create<vector::TransferReadOp>(
1146 loc, newXferVecType, xferOp.getSource(), xferIndices,
1148 xferOp.getPadding(),
Value(), inBoundsAttr);
1149 maybeAssignMask(b, xferOp, newXferOp, i);
1150 return b.
create<vector::InsertOp>(loc, newXferOp, vec,
1198 struct UnrollTransferWriteConversion
1199 :
public VectorToSCFPattern<TransferWriteOp> {
1200 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1205 setHasBoundedRewriteRecursion();
1209 Value getDataVector(TransferWriteOp xferOp)
const {
1210 if (
auto extractOp = getExtractOp(xferOp))
1211 return extractOp.getVector();
1212 return xferOp.getVector();
1216 vector::ExtractOp getExtractOp(TransferWriteOp xferOp)
const {
1217 if (
auto *op = xferOp.getVector().getDefiningOp())
1218 return dyn_cast<vector::ExtractOp>(op);
1219 return vector::ExtractOp();
1224 void getExtractionIndices(TransferWriteOp xferOp,
1226 if (
auto extractOp = getExtractOp(xferOp)) {
1227 auto pos = extractOp.getMixedPosition();
1228 indices.append(pos.begin(), pos.end());
1236 VectorType inputVectorTy = xferOp.getVectorType();
1238 if (inputVectorTy.getRank() <=
options.targetRank)
1241 if (isTensorOp(xferOp) && !
options.lowerTensors)
1244 if (inputVectorTy.getElementType() !=
1245 xferOp.getShapedType().getElementType())
1248 auto vec = getDataVector(xferOp);
1249 if (inputVectorTy.getScalableDims()[0]) {
1254 int64_t dimSize = inputVectorTy.getShape()[0];
1255 Value source = xferOp.getSource();
1256 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() :
Type();
1260 for (int64_t i = 0; i < dimSize; ++i) {
1261 Value iv = rewriter.
create<arith::ConstantIndexOp>(loc, i);
1263 auto updatedSource = generateInBoundsCheck(
1264 rewriter, xferOp, iv, unpackedDim(xferOp),
1274 getExtractionIndices(xferOp, extractionIndices);
1278 b.
create<vector::ExtractOp>(loc, vec, extractionIndices);
1279 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1281 if (inputVectorTy.getRank() == 1) {
1285 xferVec = b.
create<vector::BroadcastOp>(
1288 xferVec = extracted;
1290 auto newXferOp = b.
create<vector::TransferWriteOp>(
1291 loc, sourceType, xferVec, source, xferIndices,
1295 maybeAssignMask(b, xferOp, newXferOp, i);
1297 return isTensorOp(xferOp) ? newXferOp->getResult(0) :
Value();
1301 return isTensorOp(xferOp) ? source :
Value();
1304 if (isTensorOp(xferOp))
1305 source = updatedSource;
1308 if (isTensorOp(xferOp))
1325 template <
typename OpTy>
1326 static std::optional<int64_t>
1329 auto indices = xferOp.getIndices();
1330 auto map = xferOp.getPermutationMap();
1331 assert(xferOp.getTransferRank() > 0 &&
"unexpected 0-d transfer");
1333 memrefIndices.append(indices.begin(), indices.end());
1334 assert(map.getNumResults() == 1 &&
1335 "Expected 1 permutation map result for 1D transfer");
1336 if (
auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1338 auto dim = expr.getPosition();
1340 bindDims(xferOp.getContext(), d0, d1);
1341 Value offset = memrefIndices[dim];
1342 memrefIndices[dim] =
1347 assert(xferOp.isBroadcastDim(0) &&
1348 "Expected AffineDimExpr or AffineConstantExpr");
1349 return std::nullopt;
1354 template <
typename OpTy>
1359 struct Strategy1d<TransferReadOp> {
1361 TransferReadOp xferOp,
Value iv,
1364 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1365 auto vec = loopState[0];
1369 auto nextVec = generateInBoundsCheck(
1370 b, xferOp, iv, dim,
TypeRange(xferOp.getVectorType()),
1374 b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
1375 return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1379 b.
create<scf::YieldOp>(loc, nextVec);
1382 static Value initialLoopState(
OpBuilder &b, TransferReadOp xferOp) {
1385 return b.
create<vector::SplatOp>(loc, xferOp.getVectorType(),
1386 xferOp.getPadding());
1392 struct Strategy1d<TransferWriteOp> {
1394 TransferWriteOp xferOp,
Value iv,
1397 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1400 generateInBoundsCheck(
1404 b.
create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1405 b.
create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
1407 b.
create<scf::YieldOp>(loc);
1410 static Value initialLoopState(
OpBuilder &b, TransferWriteOp xferOp) {
1446 template <
typename OpTy>
1447 struct TransferOp1dConversion :
public VectorToSCFPattern<OpTy> {
1448 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1453 if (xferOp.getTransferRank() == 0)
1455 auto map = xferOp.getPermutationMap();
1456 auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1460 if (xferOp.getVectorType().getRank() != 1)
1467 auto vecType = xferOp.getVectorType();
1468 auto lb = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
1470 rewriter.
create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1471 if (vecType.isScalable()) {
1474 ub = rewriter.
create<arith::MulIOp>(loc, ub, vscale);
1476 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
1477 auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1483 Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1496 patterns.
add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1497 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1500 patterns.
add<lowering_n_d::PrepareTransferReadConversion,
1501 lowering_n_d::PrepareTransferWriteConversion,
1502 lowering_n_d::TransferOpConversion<TransferReadOp>,
1503 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1507 if (
options.targetRank == 1) {
1508 patterns.
add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1509 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1512 patterns.
add<lowering_n_d::DecomposePrintOpConversion>(patterns.
getContext(),
1518 struct ConvertVectorToSCFPass
1519 :
public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1520 ConvertVectorToSCFPass() =
default;
1522 this->fullUnroll =
options.unroll;
1523 this->targetRank =
options.targetRank;
1524 this->lowerTensors =
options.lowerTensors;
1527 void runOnOperation()
override {
1530 options.targetRank = targetRank;
1531 options.lowerTensors = lowerTensors;
1536 lowerTransferPatterns);
1538 std::move(lowerTransferPatterns));
1548 std::unique_ptr<Pass>
1550 return std::make_unique<ConvertVectorToSCFPass>(
options);
MLIR_CRUNNERUTILS_EXPORT void printClose()
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static Operation * getAutomaticAllocationScope(Operation *op)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A trait of region holding operations that define a new scope for automatic allocations,...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorToSCFConversionPatterns(RewritePatternSet &patterns, const VectorTransferToSCFOptions &options=VectorTransferToSCFOptions())
Collect a set of patterns to convert from the Vector dialect to SCF + func.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options=VectorTransferToSCFOptions())
Create a pass to convert a subset of vector ops to SCF.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
When lowering an N-d vector transfer op to an (N-1)-d vector transfer op, a temporary buffer is creat...