/// Shared preconditions for folding a tensor.extract_slice/insert_slice into
/// a vector transfer op.
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}
FailureOr<Value> TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
    vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  auto extractSliceOp =
      readOp.getBase().getDefiningOp<tensor::ExtractSliceOp>();
  if (!extractSliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                     extractSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  // Rebase the transfer indices onto the extract_slice source.
  SmallVector<Value> indices(readOp.getIndices().begin(),
                             readOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
      indices, sourceIndices);

  Operation *newOp = vector::TransferReadOp::create(
      rewriter, readOp.getLoc(), readOp.getVectorType(),
      extractSliceOp.getSource(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(
          readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
          extractSliceOp.getDroppedDims())),
      readOp.getPadding(),
      /*mask=*/Value(), readOp.getInBoundsAttr());
  if (maskOp)
    newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
  return newOp->getResults()[0];
}
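
/// Illustrative example (editorial addition, shapes made up): the pattern
/// below rewrites a tensor.insert_slice whose source is produced by a
/// vector.transfer_write into a transfer_write directly into the
/// insert_slice destination:
///
///   %w = vector.transfer_write %v, %src[%c0, %c0]
///       : vector<4x4xf32>, tensor<4x4xf32>
///   %r = tensor.insert_slice %w into %dst[2, 3] [4, 4] [1, 1]
///       : tensor<4x4xf32> into tensor<8x8xf32>
/// folds to:
///   %r = vector.transfer_write %v, %dst[%c2, %c3]
///       : vector<4x4xf32>, tensor<8x8xf32>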
LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
  auto writeOp = insertSliceOp.getSource()
                     .template getDefiningOp<vector::TransferWriteOp>();
  if (!writeOp)
    return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
                                                     insertSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  if (!doesTransferWriteCoverInsertSlice(writeOp))
    return rewriter.notifyMatchFailure(
        insertSliceOp, "transfer_write does not cover insert_slice");

  // Rebase the transfer indices onto the insert_slice destination.
  SmallVector<Value> indices(writeOp.getIndices().begin(),
                             writeOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
      insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
      sourceIndices);

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(
          writeOp.getPermutationMap(), insertSliceOp.getDestType().getRank(),
          insertSliceOp.getDroppedDims())),
      writeOp.getInBoundsAttr());

  return success();
}
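
/// Editorial note: the fold above is only sound when the transfer_write
/// overwrites every element of the insert_slice source (e.g. a
/// vector<4x4xf32> written into a tensor<4x4xf32>); otherwise elements of
/// the source tensor not covered by the write would be lost. The helper
/// below checks this, conservatively returning false for dynamic shapes.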
bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
    vector::TransferWriteOp writeOp) {
  if (writeOp.getShapedType().hasStaticShape())
    return llvm::equal(writeOp.getVectorType().getShape(),
                       writeOp.getShapedType().getShape());

  // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
  return false;
}
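
/// Illustrative example (editorial addition, shapes made up): the pattern
/// below folds a chain of insert_slice ops with matching sizes by composing
/// their offsets:
///
///   %a = tensor.insert_slice %x into %y[0, 0] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<2x2xf32>
///   %b = tensor.insert_slice %a into %z[2, 3] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<8x8xf32>
/// folds to:
///   %b = tensor.insert_slice %x into %z[2, 3] [2, 2] [1, 1]
///       : tensor<2x2xf32> into tensor<8x8xf32>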
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    // TODO: relax unit stride assumption where possible.
    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    int64_t srcDim = 0;
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e;
         ++d) {
      if (droppedDims[d])
        continue;
      if (insertSliceOp.getMixedSizes()[d] !=
          sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
                                        sourceInsertSliceOp.getMixedSizes(),
                                        droppedDims, resolvedSizes);

    // If this op is a parallel combining op (e.g.
    // tensor.parallel_insert_slice), temporarily hoist the insertion point
    // out of the combining region so the offset-resolving ops are created at
    // a valid position.
    OpBuilder::InsertionGuard g(rewriter);
    if (isa<mlir::ParallelCombiningOpInterface>(insertSliceOp.getOperation())) {
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    }

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // Reset the insertion point and fold the chain of insert_slice ops.
    rewriter.setInsertionPoint(insertSliceOp);
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(),
        insertSliceOp.getDest(), getAsOpFoldResult(resolvedOffsets),
        resolvedSizes, insertSliceOp.getMixedStrides());

    return success();
  }
};

void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadOfExtractSliceOpFolder,
               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
}
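
// Usage sketch (editorial addition, assuming the standard greedy pattern
// driver): clients can apply these folding patterns outside the dedicated
// pass, e.g.:
//
//   RewritePatternSet patterns(op->getContext());
//   tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));
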
namespace {
struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsPassBase<
          FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};
} // namespace

void FoldTensorSubsetOpsPass::runOnOperation() {