  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
/// Allocate temporary buffers for storing vector and mask values.
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = memref::AllocaOp::create(b, loc, maskType);
    b.setInsertionPoint(xferOp);
    memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer =
        memref::LoadOp::create(b, loc, maskBuffer, ValueRange());
  }

  return result;
}
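// Illustrative sketch of the IR produced by allocBuffers for a masked
// transfer op (names and shapes are hypothetical, not taken from this file):
//
//   // At the start of the enclosing allocation scope:
//   %data_buf = memref.alloca() : memref<vector<4x3xf32>>
//   %mask_buf = memref.alloca() : memref<vector<4xi1>>
//   // Immediately before the transfer op:
//   memref.store %mask, %mask_buf[] : memref<vector<4xi1>>
//   %mask_val = memref.load %mask_buf[] : memref<vector<4xi1>>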
/// Given a MemRefType with vector element type, unpack one dimension from the
/// vector type into the memref shape.
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with a leading scalable dim are not supported.
  if (vectorType.getScalableDims().front())
    return failure();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::Builder(vectorType).dropDim(0));
}
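// Example of the type transformation performed by unpackOneDim (shapes are
// illustrative): the leading vector dimension becomes the trailing memref
// dimension, e.g.
//
//   memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>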
/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}
/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }
  /// Find the temporary buffer allocation.
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }
  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = vector::TransferReadOp::create(
        b, loc, vecType, xferOp.getBase(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer,
                            storeIndices);
    return newXferOp;
  }
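  // Sketch of one loop iteration produced by rewriteOp (illustrative; %i is
  // the induction variable, shapes are hypothetical):
  //
  //   %vec = vector.transfer_read %A[%a + %i, %b, %c], %padding
  //       : memref<?x?x?xf32>, vector<4x3xf32>
  //   memref.store %vec, %buffer[%i] : memref<5xvector<4x3xf32>>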
  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write the
  /// padding value to the temporary buffer.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);
    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto vec =
        vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding());
    memref::StoreOp::create(b, loc, vec, buffer, storeIndices);
    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};
/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation: The vector to be written is loaded
  /// from it right before the TransferWriteOp.
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }
  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = vector::TransferWriteOp::create(
        b, loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }
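  // Sketch of one loop iteration produced by rewriteOp (illustrative; %i is
  // the induction variable, shapes are hypothetical):
  //
  //   %vec = memref.load %buffer[%i] : memref<5xvector<4x3xf32>>
  //   vector.transfer_write %vec, %A[%a + %i, %b, %c]
  //       : vector<4x3xf32>, memref<?x?x?xf32>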
  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Nothing
  /// needs to be written; just forward the loop state (if any).
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getBase() : Value();
  }
};
template <typename OpTy>
static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
                                        VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return rewriter.notifyMatchFailure(
        xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return rewriter.notifyMatchFailure(
        xferOp, "xferOp vector rank <= transformation target rank");
  if (xferOp.getVectorType().getScalableDims().front())
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking of the leading dimension into the memref is not yet "
                "supported for scalable dims");
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return rewriter.notifyMatchFailure(
        xferOp, "Unpacking for tensors has been disabled.");
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return rewriter.notifyMatchFailure(
        xferOp, "Mismatching source and destination element types.");
  return success();
}
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    memref::StoreOp::create(rewriter, loc, newXfer->getResult(0),
                            buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};
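// Sketch of the rewrite performed by PrepareTransferReadConversion
// (illustrative; the label attribute shown is kPassLabel):
//
//   %vec = vector.transfer_read ... : ..., vector<5x4xf32>
//
// becomes
//
//   %buf = memref.alloca() : memref<vector<5x4xf32>>
//   %0 = vector.transfer_read ... {__vector_to_scf_lowering__}
//       : ..., vector<5x4xf32>
//   memref.store %0, %buf[] : memref<vector<5x4xf32>>
//   %vec = memref.load %buf[] : memref<vector<5x4xf32>>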
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, rewriter, options).failed())
      return rewriter.notifyMatchFailure(
          xferOp, "checkPrepareXferOp conditions not met!");

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    memref::StoreOp::create(rewriter, loc, xferOp.getVector(),
                            buffers.dataBuffer);
    auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer);
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getValueToStoreMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }

    return success();
  }
};
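// Sketch of the rewrite performed by PrepareTransferWriteConversion
// (illustrative; the label attribute shown is kPassLabel):
//
//   vector.transfer_write %vec, %A[...] : vector<5x4xf32>, ...
//
// becomes
//
//   %buf = memref.alloca() : memref<vector<5x4xf32>>
//   memref.store %vec, %buf[] : memref<vector<5x4xf32>>
//   %0 = memref.load %buf[] : memref<vector<5x4xf32>>
//   vector.transfer_write %0, %A[...] {__vector_to_scf_lowering__}
//       : vector<5x4xf32>, ...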
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();
    // Currently >= 2-D scalable vectors are not supported: they cannot be
    // lowered to LLVM (which has no scalable vectors of scalable vectors) and
    // cannot be flattened or indexed with SSA values.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();
    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Oddly sized integers are (somewhat) buggy on many backends, so extend
      // them to the next power-of-two width of at least 8 bits.
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      // arith ops only accept signless integers, so cast back and forth.
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType,
                                        value);
      if (value.getType() != signlessTargetVectorType) {
        if (width == 1 || intTy.isUnsigned())
          value = arith::ExtUIOp::create(rewriter, loc,
                                         signlessTargetVectorType, value);
        else
          value = arith::ExtSIOp::create(rewriter, loc,
                                         signlessTargetVectorType, value);
      }
      value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value);
      vectorType = targetVectorType;
    }
    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1-D so elements can be indexed with a
      // non-constant value.
      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
    }
    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Setup loop bounds and step.
      Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
      Value upperBound =
          arith::ConstantIndexOp::create(rewriter, loc, shape[d]);
      Value step = arith::ConstantIndexOp::create(rewriter, loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        auto vscale = vector::VectorScaleOp::create(rewriter, loc,
                                                    rewriter.getIndexType());
        upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale);
      }
      auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step);

      // Create a loop to print the elements surrounded by parentheses.
      vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
      auto loop =
          scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
      auto printClose = vector::PrintOp::create(
          rewriter, loc, vector::PrintPunctuation::Close);
      if (!firstClose)
        firstClose = printClose;

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after all but the last element.
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = arith::CmpIOp::create(
          rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      scf::IfOp::create(rewriter, loc, notLastIndex,
                        [&](OpBuilder &builder, Location loc) {
                          vector::PrintOp::create(
                              builder, loc, vector::PrintPunctuation::Comma);
                          scf::YieldOp::create(builder, loc);
                        });
    }
    // Compute the flattened index. (For rank > 1 this assumes non-scalable.)
    Value flatIndex;
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride =
          arith::ConstantIndexOp::create(rewriter, loc, currentStride);
      auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }
    // Print the scalar elements in the innermost loop.
    auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex);
    vector::PrintOp::create(rewriter, loc, element,
                            vector::PrintPunctuation::NoPunctuation);

    rewriter.setInsertionPointAfter(firstClose);
    vector::PrintOp::create(rewriter, loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }

  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  }
};
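// Rough sketch of the decomposition for a 1-D scalable vector (illustrative;
// the exact IR depends on folding):
//
//   vector.print %v : vector<[4]xi32>
//
// becomes
//
//   %vscale = vector.vscale
//   %ub = arith.muli %c4, %vscale : index
//   %last = arith.subi %ub, %c1 : index
//   vector.print punctuation <open>
//   scf.for %i = %c0 to %ub step %c1 {
//     %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
//     vector.print %el : i32 punctuation <no_punctuation>
//     %not_last = arith.cmpi ult, %i, %last : index
//     scf.if %not_last {
//       vector.print punctuation <comma>
//     }
//   }
//   vector.print punctuation <close>
//   vector.print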
/// Progressive lowering of vector transfer ops: Unpack one dimension.
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank of the unpacked transfer op strictly decreases.
    this->setHasBoundedRewriteRecursion();
  }
  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    // Add load indices from the previous iteration. The mask buffer layout
    // depends on the permutation map, so "look back" at an existing load to
    // find the right indices.
    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      // If there is no previous load op, the indices stay empty.
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        auto prevIndices = loadOp.getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());
        break;
      }
    }

    // The unpacked dim contributes an index only if it is not a broadcast.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }
  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return rewriter.notifyMatchFailure(
          xferOp, "kPassLabel is absent (op not prepared for lowering)");

    // Find and cast data buffer. How the buffer is found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return rewriter.notifyMatchFailure(xferOp,
                                         "Failed to unpack one vector dim.");

    auto castedDataBuffer =
        vector::TypeCastOp::create(locB, *castedDataType, dataBuffer);
    // If the xferOp has a mask: Find and cast the mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask if:
        // * The to-be-unpacked transfer op dimension is a broadcast.
        // * The mask is 1-D, i.e., it cannot be further unpacked.
        castedMaskBuffer = maskBuffer;
      } else {
        // It's safe to assume the mask buffer can be unpacked if the data
        // buffer was unpacked.
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            vector::TypeCastOp::create(locB, castedMaskType, maskBuffer);
      }
    }
    // Loop bounds and step.
    auto lb = arith::ConstantIndexOp::create(locB, 0);
    auto ub = arith::ConstantIndexOp::create(
        locB, castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = arith::ConstantIndexOp::create(locB, 1);
    // TransferWriteOps that operate on tensors return a modified tensor and
    // need a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = scf::ForOp::create(
        locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();
          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create a new transfer op with one unpacked dimension.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If the old transfer op has a mask, set the mask on the new
                // op. A 1-D mask on a non-broadcast dim needs no propagation.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.
                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer,
                                                     loadIndices);
                  rewriter.modifyOpInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }
                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });
          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};
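// Sketch of the buffer unpacking performed by TransferOpConversion
// (illustrative shapes): a buffer of 2-D vectors
//
//   %buf : memref<5xvector<4x3xf32>>
//
// is cast via
//
//   %casted = vector.type_cast %buf
//       : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
//
// and the generated scf.for loop then iterates over the new trailing
// dimension of size 4, so each iteration transfers a rank-reduced
// vector<3xf32>.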
/// Retrieve the dimension sizes of a vector mask as OpFoldResults.
template <typename VscaleConstantBuilder>
static FailureOr<SmallVector<OpFoldResult>>
getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
  if (!mask)
    return SmallVector<OpFoldResult>{};
  if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
    return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
      return OpFoldResult(dimSize);
    });
  }
  if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
    int64_t dimIdx = 0;
    VectorType maskType = constantMask.getVectorType();
    auto indexType = IndexType::get(mask.getContext());
    return llvm::map_to_vector(
        constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
          // A scalable dim size is a multiple of vscale.
          if (maskType.getScalableDims()[dimIdx++])
            return OpFoldResult(createVscaleMultiple(dimSize));
          return OpFoldResult(IntegerAttr::get(indexType, dimSize));
        });
  }
  return failure();
}
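// For example (illustrative): a mask defined by
//
//   %mask = vector.create_mask %a, %b : vector<4x[4]xi1>
//
// yields {%a, %b}; a vector.constant_mask with dim sizes [2, 3] on the same
// type yields {2, %c3_vscale}, where %c3_vscale is 3 * vscale materialized by
// createVscaleMultiple, since the trailing dim is scalable.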
struct ScalableTransposeTransferWriteConversion
    : VectorToSCFPattern<vector::TransferWriteOp> {
  using VectorToSCFPattern::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (failed(checkLowerTensors(writeOp, rewriter)))
      return failure();

    VectorType vectorType = writeOp.getVectorType();
    // Comparing the scalable dims to an ArrayRef of length two implicitly
    // checks that the rank is also two.
    if (vectorType.getScalableDims() != ArrayRef<bool>{true, false}) {
      return rewriter.notifyMatchFailure(
          writeOp, "expected vector of the form vector<[N]xMxty>");
    }

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity()) {
      return rewriter.notifyMatchFailure(
          writeOp, "non-identity permutations are unsupported (lower first)");
    }

    // Only the leading dimension is lowered to a loop, so only it must be in
    // bounds; the trailing in-bounds flag is propagated.
    if (!writeOp.isDimInBounds(0)) {
      return rewriter.notifyMatchFailure(
          writeOp, "out-of-bounds dims are unsupported (use masking)");
    }
    Value vector = writeOp.getVector();
    auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
    if (!transposeOp ||
        transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
      return rewriter.notifyMatchFailure(writeOp, "source not transpose");
    }

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
    if (failed(maskDims)) {
      return rewriter.notifyMatchFailure(writeOp,
                                         "failed to resolve mask dims");
    }

    int64_t fixedDimSize = vectorType.getDimSize(1);
    auto fixedDimOffsets = llvm::seq(fixedDimSize);

    // Extract all slices from the source of the transpose.
    auto transposeSource = transposeOp.getVector();
    SmallVector<Value> transposeSourceSlices =
        llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
          return vector::ExtractOp::create(rewriter, loc, transposeSource, idx);
        });
    // Loop bounds and step.
    auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto ub =
        maskDims->empty()
            ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
            : getValueOrCreateConstantIndexOp(rewriter, loc, maskDims->front());
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);

    // Generate a new mask for the slice.
    VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
    Value sliceMask = nullptr;
    if (!maskDims->empty()) {
      sliceMask = vector::CreateMaskOp::create(
          rewriter, loc, sliceType.clone(rewriter.getI1Type()),
          ArrayRef<OpFoldResult>(*maskDims).drop_front());
    }

    Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
    ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
    auto result = scf::ForOp::create(
        rewriter, loc, lb, ub, step, initLoopArgs,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
          // Indices for the new transfer op.
          SmallVector<Value, 8> xferIndices;
          getXferIndices(b, writeOp, iv, xferIndices);

          // Extract a transposed slice from the source vector.
          SmallVector<Value> transposeElements =
              llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
                return vector::ExtractOp::create(
                    b, loc, transposeSourceSlices[idx], iv);
              });
          auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType,
                                                         transposeElements);

          // Create a transfer_write for the slice.
          Value dest =
              loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
          auto newWriteOp = vector::TransferWriteOp::create(
              b, loc, sliceVec, dest, xferIndices,
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
          if (sliceMask)
            newWriteOp.getMaskMutable().assign(sliceMask);

          scf::YieldOp::create(b, loc,
                               loopIterArgs.empty() ? ValueRange{}
                                                    : newWriteOp.getResult());
        });

    if (isTensorOp(writeOp))
      rewriter.replaceOp(writeOp, result);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }
};
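// Rough sketch of this rewrite (illustrative): a transfer_write of a scalable
// transpose such as
//
//   %t = vector.transpose %vec, [1, 0]
//       : vector<4x[4]xf32> to vector<[4]x4xf32>
//   vector.transfer_write %t, %dest[%i, %j] {in_bounds = [true, true]}
//       : vector<[4]x4xf32>, memref<?x?xf32>
//
// becomes a loop over the scalable dimension that, per iteration, gathers one
// element from each of the four extracted source rows with vector.extract,
// assembles them into a vector<4xf32> with vector.from_elements, and writes
// that slice with a rank-reduced vector.transfer_write.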
/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current unrolled iteration `i`) and assign it.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // The to-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. The mask attribute remains unchanged.
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1-D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
}
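// For example (illustrative): at unrolled iteration i = 2, a transfer with a
// vector<4x3xi1> mask is assigned the rank-reduced mask
//
//   %new_mask = vector.extract %mask[2] : vector<3xi1> from vector<4x3xi1>
//
// A 1-D mask on a non-broadcast dim gets no mask here; that case is covered
// by the per-iteration bounds/mask check (generateInBoundsCheck) instead.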
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank of the unpacked transfer op strictly decreases.
    setHasBoundedRewriteRecursion();
  }

  /// Get or build the vector into which the newly created TransferReadOp
  /// results are inserted.
  Value buildResultVector(PatternRewriter &rewriter,
                          TransferReadOp xferOp) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(),
                                       xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }
    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      auto pos = insertOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }
  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return rewriter.notifyMatchFailure(
          xferOp, "vector rank is less or equal to target rank");
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return rewriter.notifyMatchFailure(
          xferOp, "not yet supported: element type mismatch");
    auto xferVecType = xferOp.getVectorType();
    if (xferVecType.getScalableDims()[0]) {
      return rewriter.notifyMatchFailure(
          xferOp, "scalable dimensions cannot be unrolled at compile time");
    }

    auto insertOp = getInsertOp(xferOp);
    auto vec = buildResultVector(rewriter, xferOp);
    auto vecType = dyn_cast<VectorType>(vec.getType());
    VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.insert op.
            SmallVector<OpFoldResult, 8> insertionIndices;
            getInsertionIndices(xferOp, insertionIndices);
            insertionIndices.push_back(rewriter.getIndexAttr(i));

            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            auto newXferOp = vector::TransferReadOp::create(
                b, loc, newXferVecType, xferOp.getBase(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
                xferOp.getPadding(), Value(), inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);

            Value valToInsert = newXferOp.getResult();
            if (newXferVecType.getRank() == 0) {
              // vector.insert does not accept rank-0 vectors; extract the
              // scalar before inserting.
              valToInsert = vector::ExtractOp::create(b, loc, valToInsert,
                                                      SmallVector<int64_t>());
            }
            return vector::InsertOp::create(b, loc, valToInsert, vec,
                                            insertionIndices);
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Loop through the original (unmodified) vector.
            return vec;
          });
    }

    if (insertOp) {
      // Rewrite the single user of the old TransferReadOp, the vector.insert.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }
    return success();
  }
};
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank of the unpacked transfer op strictly decreases.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getSource();
    return xferOp.getVector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }
  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();
    if (failed(checkLowerTensors(xferOp, rewriter)))
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (inputVectorTy.getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto vec = getDataVector(xferOp);
    if (inputVectorTy.getScalableDims()[0]) {
      // Cannot unroll a scalable dimension at compile time.
      return failure();
    }

    int64_t dimSize = inputVectorTy.getShape()[0];
    Value source = xferOp.getBase(); // memref or tensor to be written to.
    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);

      auto updatedSource = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<OpFoldResult, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(b.getI64IntegerAttr(i));

            auto extracted =
                vector::ExtractOp::create(b, loc, vec, extractionIndices);
            auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
            Value xferVec;
            if (inputVectorTy.getRank() == 1) {
              // When targetRank == 0, unrolling would cause the vector operand
              // of the transfer_write to become a scalar, so broadcast the
              // scalar to a 0-D vector instead.
              xferVec = vector::BroadcastOp::create(
                  b, loc, VectorType::get({}, extracted.getType()), extracted);
            } else {
              xferVec = extracted;
            }
            auto newXferOp = vector::TransferWriteOp::create(
                b, loc, sourceType, xferVec, source, xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);
            maybeAssignMask(b, xferOp, newXferOp, i);
            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
          },
          /*outOfBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            return isTensorOp(xferOp) ? source : Value();
          });

      if (isTensorOp(xferOp))
        source = updatedSource;
    }

    if (isTensorOp(xferOp))
      rewriter.replaceOp(xferOp, source);
    else
      rewriter.eraseOp(xferOp);
    return success();
  }
};