15 #include <type_traits>
33 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
34 #include "mlir/Conversion/Passes.h.inc"
38 using vector::TransferReadOp;
39 using vector::TransferWriteOp;
44 static const char kPassLabel[] =
"__vector_to_scf_lowering__";
48 template <
typename OpTy>
60 template <
typename OpTy>
61 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
63 assert(xferOp.getTransferRank() > 0 &&
"unexpected 0-d transfer");
64 auto map = xferOp.getPermutationMap();
65 if (
auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
66 return expr.getPosition();
68 assert(xferOp.isBroadcastDim(0) &&
69 "Expected AffineDimExpr or AffineConstantExpr");
76 template <
typename OpTy>
79 assert(xferOp.getTransferRank() > 0 &&
"unexpected 0-d transfer");
80 auto map = xferOp.getPermutationMap();
81 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
91 template <
typename OpTy>
94 typename OpTy::Adaptor adaptor(xferOp);
96 auto dim = unpackedDim(xferOp);
97 auto prevIndices = adaptor.getIndices();
98 indices.append(prevIndices.begin(), prevIndices.end());
101 bool isBroadcast = !dim.has_value();
104 bindDims(xferOp.getContext(), d0, d1);
105 Value offset = adaptor.getIndices()[*dim];
114 assert(value &&
"Expected non-empty value");
115 b.
create<scf::YieldOp>(loc, value);
117 b.
create<scf::YieldOp>(loc);
127 template <
typename OpTy>
129 if (!xferOp.getMask())
131 if (xferOp.getMaskType().getRank() != 1)
133 if (xferOp.isBroadcastDim(0))
137 return b.
create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
164 template <
typename OpTy>
165 static Value generateInBoundsCheck(
170 bool hasRetVal = !resultTypes.empty();
174 bool isBroadcast = !dim;
177 if (!xferOp.isDimInBounds(0) && !isBroadcast) {
181 bindDims(xferOp.getContext(), d0, d1);
182 Value base = xferOp.getIndices()[*dim];
185 cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
190 if (
auto maskCond = generateMaskCheck(b, xferOp, iv)) {
192 cond = lb.create<arith::AndIOp>(cond, maskCond);
199 auto check = lb.create<scf::IfOp>(
203 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
207 if (outOfBoundsCase) {
208 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
210 b.
create<scf::YieldOp>(loc);
214 return hasRetVal ? check.getResult(0) :
Value();
218 return inBoundsCase(b, loc);
223 template <
typename OpTy>
224 static void generateInBoundsCheck(
228 generateInBoundsCheck(
232 inBoundsCase(b, loc);
238 outOfBoundsCase(b, loc);
244 static ArrayAttr dropFirstElem(
OpBuilder &b, ArrayAttr attr) {
252 template <
typename OpTy>
253 static void maybeApplyPassLabel(
OpBuilder &b, OpTy newXferOp,
254 unsigned targetRank) {
255 if (newXferOp.getVectorType().getRank() > targetRank)
260 template <
typename OpTy>
261 static bool isTensorOp(OpTy xferOp) {
262 if (isa<RankedTensorType>(xferOp.getShapedType())) {
263 if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
265 assert(xferOp->getNumResults() > 0);
275 struct BufferAllocs {
284 assert(scope &&
"Expected op to be inside automatic allocation scope");
289 template <
typename OpTy>
290 static BufferAllocs allocBuffers(
OpBuilder &b, OpTy xferOp) {
295 "AutomaticAllocationScope with >1 regions");
300 result.dataBuffer = b.
create<memref::AllocaOp>(loc, bufferType);
302 if (xferOp.getMask()) {
304 auto maskBuffer = b.
create<memref::AllocaOp>(loc, maskType);
306 b.
create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
307 result.maskBuffer = b.
create<memref::LoadOp>(loc, maskBuffer,
ValueRange());
318 auto vectorType = dyn_cast<VectorType>(type.getElementType());
321 if (vectorType.getScalableDims().front())
323 auto memrefShape = type.getShape();
325 newMemrefShape.append(memrefShape.begin(), memrefShape.end());
326 newMemrefShape.push_back(vectorType.getDimSize(0));
333 template <
typename OpTy>
334 static Value getMaskBuffer(OpTy xferOp) {
335 assert(xferOp.getMask() &&
"Expected that transfer op has mask");
336 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
337 assert(loadOp &&
"Expected transfer op mask produced by LoadOp");
338 return loadOp.getMemRef();
342 template <
typename OpTy>
347 struct Strategy<TransferReadOp> {
350 static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
351 assert(xferOp->hasOneUse() &&
"Expected exactly one use of TransferReadOp");
352 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
353 assert(storeOp &&
"Expected TransferReadOp result used by StoreOp");
365 return getStoreOp(xferOp).getMemRef();
369 static void getBufferIndices(TransferReadOp xferOp,
371 auto storeOp = getStoreOp(xferOp);
372 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
373 indices.append(prevIndices.begin(), prevIndices.end());
403 static TransferReadOp rewriteOp(
OpBuilder &b,
405 TransferReadOp xferOp,
Value buffer,
Value iv,
408 getBufferIndices(xferOp, storeIndices);
409 storeIndices.push_back(iv);
415 auto bufferType = dyn_cast<ShapedType>(buffer.
getType());
416 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
417 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
418 auto newXferOp = b.
create<vector::TransferReadOp>(
419 loc, vecType, xferOp.getSource(), xferIndices,
421 xferOp.getPadding(),
Value(), inBoundsAttr);
423 maybeApplyPassLabel(b, newXferOp,
options.targetRank);
425 b.
create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
431 static Value handleOutOfBoundsDim(
OpBuilder &b, TransferReadOp xferOp,
435 getBufferIndices(xferOp, storeIndices);
436 storeIndices.push_back(iv);
439 auto bufferType = dyn_cast<ShapedType>(buffer.
getType());
440 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
441 auto vec = b.
create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
442 b.
create<memref::StoreOp>(loc, vec, buffer, storeIndices);
450 rewriter.
eraseOp(getStoreOp(xferOp));
/// Returns a default-constructed Value for a TransferReadOp; the `xferOp`
/// argument is not consulted.
455 static Value initialLoopState(TransferReadOp xferOp) {
return Value(); }
460 struct Strategy<TransferWriteOp> {
469 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
470 assert(loadOp &&
"Expected transfer op vector produced by LoadOp");
471 return loadOp.getMemRef();
475 static void getBufferIndices(TransferWriteOp xferOp,
477 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
478 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
479 indices.append(prevIndices.begin(), prevIndices.end());
491 static TransferWriteOp rewriteOp(
OpBuilder &b,
493 TransferWriteOp xferOp,
Value buffer,
496 getBufferIndices(xferOp, loadIndices);
497 loadIndices.push_back(iv);
503 auto vec = b.
create<memref::LoadOp>(loc, buffer, loadIndices);
504 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
505 auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
506 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() :
Type();
507 auto newXferOp = b.
create<vector::TransferWriteOp>(
508 loc, type, vec, source, xferIndices,
512 maybeApplyPassLabel(b, newXferOp,
options.targetRank);
518 static Value handleOutOfBoundsDim(
OpBuilder &b, TransferWriteOp xferOp,
521 return isTensorOp(xferOp) ? loopState[0] :
Value();
527 if (isTensorOp(xferOp)) {
528 assert(forOp->getNumResults() == 1 &&
"Expected one for loop result");
529 rewriter.
replaceOp(xferOp, forOp->getResult(0));
536 static Value initialLoopState(TransferWriteOp xferOp) {
537 return isTensorOp(xferOp) ? xferOp.getSource() :
Value();
541 template <
typename OpTy>
544 if (xferOp->hasAttr(kPassLabel))
546 if (xferOp.getVectorType().getRank() <=
options.targetRank)
550 if (xferOp.getVectorType().getScalableDims().front())
552 if (isTensorOp(xferOp) && !
options.lowerTensors)
555 if (xferOp.getVectorType().getElementType() !=
556 xferOp.getShapedType().getElementType())
584 struct PrepareTransferReadConversion
585 :
public VectorToSCFPattern<TransferReadOp> {
586 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
593 auto buffers = allocBuffers(rewriter, xferOp);
594 auto *newXfer = rewriter.
clone(*xferOp.getOperation());
596 if (xferOp.getMask()) {
597 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
602 rewriter.
create<memref::StoreOp>(loc, newXfer->getResult(0),
633 struct PrepareTransferWriteConversion
634 :
public VectorToSCFPattern<TransferWriteOp> {
635 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
643 auto buffers = allocBuffers(rewriter, xferOp);
644 rewriter.
create<memref::StoreOp>(loc, xferOp.getVector(),
646 auto loadedVec = rewriter.
create<memref::LoadOp>(loc, buffers.dataBuffer);
648 xferOp.getVectorMutable().assign(loadedVec);
649 xferOp->setAttr(kPassLabel, rewriter.
getUnitAttr());
652 if (xferOp.getMask()) {
654 xferOp.getMaskMutable().assign(buffers.maskBuffer);
689 struct DecomposePrintOpConversion :
public VectorToSCFPattern<vector::PrintOp> {
690 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
696 VectorType vectorType = dyn_cast<VectorType>(
printOp.getPrintType());
706 if (vectorType.getRank() > 1 && vectorType.isScalable())
710 auto value =
printOp.getSource();
712 if (
auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
716 auto width = intTy.getWidth();
717 auto legalWidth = llvm::NextPowerOf2(
std::max(8u, width) - 1);
719 intTy.getSignedness());
721 auto signlessSourceVectorType =
722 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
723 auto signlessTargetVectorType =
724 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
725 auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
726 value = rewriter.
create<vector::BitCastOp>(loc, signlessSourceVectorType,
728 if (width == 1 || intTy.isUnsigned())
729 value = rewriter.
create<arith::ExtUIOp>(loc, signlessTargetVectorType,
732 value = rewriter.
create<arith::ExtSIOp>(loc, signlessTargetVectorType,
734 value = rewriter.
create<vector::BitCastOp>(loc, targetVectorType, value);
735 vectorType = targetVectorType;
738 auto scalableDimensions = vectorType.getScalableDims();
739 auto shape = vectorType.getShape();
740 constexpr int64_t singletonShape[] = {1};
741 if (vectorType.getRank() == 0)
742 shape = singletonShape;
744 if (vectorType.getRank() != 1) {
748 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
749 std::multiplies<int64_t>());
750 auto flatVectorType =
752 value = rewriter.
create<vector::ShapeCastOp>(loc, flatVectorType, value);
755 vector::PrintOp firstClose;
757 for (
unsigned d = 0; d < shape.size(); d++) {
759 Value lowerBound = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
760 Value upperBound = rewriter.
create<arith::ConstantIndexOp>(loc, shape[d]);
761 Value step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
762 if (!scalableDimensions.empty() && scalableDimensions[d]) {
763 auto vscale = rewriter.
create<vector::VectorScaleOp>(
765 upperBound = rewriter.
create<arith::MulIOp>(loc, upperBound, vscale);
767 auto lastIndex = rewriter.
create<arith::SubIOp>(loc, upperBound, step);
770 rewriter.
create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
772 rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound, step);
774 loc, vector::PrintPunctuation::Close);
778 auto loopIdx = loop.getInductionVar();
779 loopIndices.push_back(loopIdx);
783 auto notLastIndex = rewriter.
create<arith::CmpIOp>(
784 loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
785 rewriter.
create<scf::IfOp>(loc, notLastIndex,
787 builder.create<vector::PrintOp>(
788 loc, vector::PrintPunctuation::Comma);
789 builder.create<scf::YieldOp>(loc);
798 auto currentStride = 1;
799 for (
int d = shape.size() - 1; d >= 0; d--) {
800 auto stride = rewriter.
create<arith::ConstantIndexOp>(loc, currentStride);
801 auto index = rewriter.
create<arith::MulIOp>(loc, stride, loopIndices[d]);
803 flatIndex = rewriter.
create<arith::AddIOp>(loc, flatIndex, index);
806 currentStride *= shape[d];
811 rewriter.
create<vector::ExtractElementOp>(loc, value, flatIndex);
812 rewriter.
create<vector::PrintOp>(loc, element,
813 vector::PrintPunctuation::NoPunctuation);
816 rewriter.
create<vector::PrintOp>(loc,
printOp.getPunctuation());
821 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
823 IntegerType::Signless);
856 template <
typename OpTy>
857 struct TransferOpConversion :
public VectorToSCFPattern<OpTy> {
858 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
863 this->setHasBoundedRewriteRecursion();
868 if (!xferOp->hasAttr(kPassLabel))
874 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
875 auto castedDataType = unpackOneDim(dataBufferType);
876 if (
failed(castedDataType))
879 auto castedDataBuffer =
880 locB.
create<vector::TypeCastOp>(*castedDataType, dataBuffer);
883 Value castedMaskBuffer;
884 if (xferOp.getMask()) {
885 auto maskBuffer = getMaskBuffer(xferOp);
886 auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
887 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
893 castedMaskBuffer = maskBuffer;
897 auto castedMaskType = *unpackOneDim(maskBufferType);
899 locB.
create<vector::TypeCastOp>(castedMaskType, maskBuffer);
904 auto lb = locB.
create<arith::ConstantIndexOp>(0);
905 auto ub = locB.
create<arith::ConstantIndexOp>(
906 castedDataType->getDimSize(castedDataType->getRank() - 1));
907 auto step = locB.
create<arith::ConstantIndexOp>(1);
910 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
913 auto result = locB.
create<scf::ForOp>(
916 Type stateType = loopState.empty() ?
Type() : loopState[0].getType();
918 auto result = generateInBoundsCheck(
919 b, xferOp, iv, unpackedDim(xferOp),
924 OpTy newXfer = Strategy<OpTy>::rewriteOp(
925 b, this->options, xferOp, castedDataBuffer, iv, loopState);
932 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
933 xferOp.getMaskType().getRank() > 1)) {
938 Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
941 if (!xferOp.isBroadcastDim(0))
942 loadIndices.push_back(iv);
944 auto mask = b.
create<memref::LoadOp>(loc, castedMaskBuffer,
947 newXfer.getMaskMutable().assign(mask);
951 return loopState.empty() ?
Value() : newXfer->getResult(0);
955 return Strategy<OpTy>::handleOutOfBoundsDim(
956 b, xferOp, castedDataBuffer, iv, loopState);
959 maybeYieldValue(b, loc, !loopState.empty(), result);
962 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
973 template <
typename OpTy>
974 static void maybeAssignMask(
OpBuilder &b, OpTy xferOp, OpTy newXferOp,
976 if (!xferOp.getMask())
979 if (xferOp.isBroadcastDim(0)) {
982 newXferOp.getMaskMutable().assign(xferOp.getMask());
986 if (xferOp.getMaskType().getRank() > 1) {
993 auto newMask = b.
create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
994 newXferOp.getMaskMutable().assign(newMask);
1030 struct UnrollTransferReadConversion
1031 :
public VectorToSCFPattern<TransferReadOp> {
1032 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1037 setHasBoundedRewriteRecursion();
1042 Value getResultVector(TransferReadOp xferOp,
1044 if (
auto insertOp = getInsertOp(xferOp))
1045 return insertOp.getDest();
1047 return rewriter.
create<vector::SplatOp>(loc, xferOp.getVectorType(),
1048 xferOp.getPadding());
1053 vector::InsertOp getInsertOp(TransferReadOp xferOp)
const {
1054 if (xferOp->hasOneUse()) {
1056 if (
auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1060 return vector::InsertOp();
1065 void getInsertionIndices(TransferReadOp xferOp,
1067 if (
auto insertOp = getInsertOp(xferOp)) {
1068 auto pos = insertOp.getMixedPosition();
1069 indices.append(pos.begin(), pos.end());
1077 if (xferOp.getVectorType().getRank() <=
options.targetRank)
1079 if (isTensorOp(xferOp) && !
options.lowerTensors)
1082 if (xferOp.getVectorType().getElementType() !=
1083 xferOp.getShapedType().getElementType())
1086 auto insertOp = getInsertOp(xferOp);
1087 auto vec = getResultVector(xferOp, rewriter);
1088 auto vecType = dyn_cast<VectorType>(vec.getType());
1089 auto xferVecType = xferOp.getVectorType();
1091 if (xferVecType.getScalableDims()[0]) {
1098 int64_t dimSize = xferVecType.getShape()[0];
1102 for (int64_t i = 0; i < dimSize; ++i) {
1103 Value iv = rewriter.
create<arith::ConstantIndexOp>(loc, i);
1105 vec = generateInBoundsCheck(
1106 rewriter, xferOp, iv, unpackedDim(xferOp),
TypeRange(vecType),
1115 getInsertionIndices(xferOp, insertionIndices);
1118 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1119 auto newXferOp = b.
create<vector::TransferReadOp>(
1120 loc, newXferVecType, xferOp.getSource(), xferIndices,
1122 xferOp.getPadding(),
Value(), inBoundsAttr);
1123 maybeAssignMask(b, xferOp, newXferOp, i);
1124 return b.
create<vector::InsertOp>(loc, newXferOp, vec,
1172 struct UnrollTransferWriteConversion
1173 :
public VectorToSCFPattern<TransferWriteOp> {
1174 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1179 setHasBoundedRewriteRecursion();
1183 Value getDataVector(TransferWriteOp xferOp)
const {
1184 if (
auto extractOp = getExtractOp(xferOp))
1185 return extractOp.getVector();
1186 return xferOp.getVector();
1190 vector::ExtractOp getExtractOp(TransferWriteOp xferOp)
const {
1191 if (
auto *op = xferOp.getVector().getDefiningOp())
1192 return dyn_cast<vector::ExtractOp>(op);
1193 return vector::ExtractOp();
1198 void getExtractionIndices(TransferWriteOp xferOp,
1200 if (
auto extractOp = getExtractOp(xferOp)) {
1201 auto pos = extractOp.getMixedPosition();
1202 indices.append(pos.begin(), pos.end());
1210 if (xferOp.getVectorType().getRank() <=
options.targetRank)
1212 if (isTensorOp(xferOp) && !
options.lowerTensors)
1215 if (xferOp.getVectorType().getElementType() !=
1216 xferOp.getShapedType().getElementType())
1219 auto vec = getDataVector(xferOp);
1220 auto xferVecType = xferOp.getVectorType();
1221 int64_t dimSize = xferVecType.getShape()[0];
1222 Value source = xferOp.getSource();
1223 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() :
Type();
1227 for (int64_t i = 0; i < dimSize; ++i) {
1228 Value iv = rewriter.
create<arith::ConstantIndexOp>(loc, i);
1230 auto updatedSource = generateInBoundsCheck(
1231 rewriter, xferOp, iv, unpackedDim(xferOp),
1241 getExtractionIndices(xferOp, extractionIndices);
1245 b.
create<vector::ExtractOp>(loc, vec, extractionIndices);
1246 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1247 auto newXferOp = b.
create<vector::TransferWriteOp>(
1248 loc, sourceType, extracted, source, xferIndices,
1252 maybeAssignMask(b, xferOp, newXferOp, i);
1254 return isTensorOp(xferOp) ? newXferOp->getResult(0) :
Value();
1258 return isTensorOp(xferOp) ? source :
Value();
1261 if (isTensorOp(xferOp))
1262 source = updatedSource;
1265 if (isTensorOp(xferOp))
1282 template <
typename OpTy>
1283 static std::optional<int64_t>
1286 auto indices = xferOp.getIndices();
1287 auto map = xferOp.getPermutationMap();
1288 assert(xferOp.getTransferRank() > 0 &&
"unexpected 0-d transfer");
1290 memrefIndices.append(indices.begin(), indices.end());
1291 assert(map.getNumResults() == 1 &&
1292 "Expected 1 permutation map result for 1D transfer");
1293 if (
auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
1295 auto dim = expr.getPosition();
1297 bindDims(xferOp.getContext(), d0, d1);
1298 Value offset = memrefIndices[dim];
1299 memrefIndices[dim] =
1304 assert(xferOp.isBroadcastDim(0) &&
1305 "Expected AffineDimExpr or AffineConstantExpr");
1306 return std::nullopt;
1311 template <
typename OpTy>
1316 struct Strategy1d<TransferReadOp> {
1318 TransferReadOp xferOp,
Value iv,
1321 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1322 auto vec = loopState[0];
1326 auto nextVec = generateInBoundsCheck(
1327 b, xferOp, iv, dim,
TypeRange(xferOp.getVectorType()),
1331 b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
1332 return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1336 b.
create<scf::YieldOp>(loc, nextVec);
1339 static Value initialLoopState(
OpBuilder &b, TransferReadOp xferOp) {
1342 return b.
create<vector::SplatOp>(loc, xferOp.getVectorType(),
1343 xferOp.getPadding());
1349 struct Strategy1d<TransferWriteOp> {
1351 TransferWriteOp xferOp,
Value iv,
1354 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1357 generateInBoundsCheck(
1361 b.
create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1362 b.
create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
1364 b.
create<scf::YieldOp>(loc);
1367 static Value initialLoopState(
OpBuilder &b, TransferWriteOp xferOp) {
1403 template <
typename OpTy>
1404 struct TransferOp1dConversion :
public VectorToSCFPattern<OpTy> {
1405 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1410 if (xferOp.getTransferRank() == 0)
1412 auto map = xferOp.getPermutationMap();
1413 auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1417 if (xferOp.getVectorType().getRank() != 1)
1424 auto vecType = xferOp.getVectorType();
1425 auto lb = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
1427 rewriter.
create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1428 if (vecType.isScalable()) {
1431 ub = rewriter.
create<arith::MulIOp>(loc, ub, vscale);
1433 auto step = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
1434 auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1440 Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1453 patterns.
add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1454 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1457 patterns.
add<lowering_n_d::PrepareTransferReadConversion,
1458 lowering_n_d::PrepareTransferWriteConversion,
1459 lowering_n_d::TransferOpConversion<TransferReadOp>,
1460 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1464 if (
options.targetRank == 1) {
1465 patterns.
add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1466 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1469 patterns.
add<lowering_n_d::DecomposePrintOpConversion>(patterns.
getContext(),
1475 struct ConvertVectorToSCFPass
1476 :
public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1477 ConvertVectorToSCFPass() =
default;
1479 this->fullUnroll =
options.unroll;
1480 this->targetRank =
options.targetRank;
1481 this->lowerTensors =
options.lowerTensors;
1484 void runOnOperation()
override {
1487 options.targetRank = targetRank;
1488 options.lowerTensors = lowerTensors;
1493 lowerTransferPatterns);
1495 std::move(lowerTransferPatterns));
1505 std::unique_ptr<Pass>
1507 return std::make_unique<ConvertVectorToSCFPass>(
options);
MLIR_CRUNNERUTILS_EXPORT void printClose()
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void printOp(llvm::raw_ostream &os, Operation *op, OpPrintingFlags &flags)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static Operation * getAutomaticAllocationScope(Operation *op)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
A trait of region holding operations that define a new scope for automatic allocations,...
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isLastMemrefDimUnitStride(MemRefType type)
Return "true" if the last dimension of the given type has a static unit stride.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateVectorToSCFConversionPatterns(RewritePatternSet &patterns, const VectorTransferToSCFOptions &options=VectorTransferToSCFOptions())
Collect a set of patterns to convert from the Vector dialect to SCF + func.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options=VectorTransferToSCFOptions())
Create a pass to convert a subset of vector ops to SCF.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
When lowering an N-d vector transfer op to an (N-1)-d vector transfer op, a temporary buffer is creat...