1//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements lowering of vector transfer operations to SCF.
10//
11//===----------------------------------------------------------------------===//
12
13#include <numeric>
14#include <optional>
15
16#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
17
18#include "mlir/Dialect/Affine/IR/AffineOps.h"
19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/Dialect/Tensor/IR/Tensor.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26#include "mlir/IR/Builders.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#include "llvm/ADT/STLExtras.h"
30
31namespace mlir {
32#define GEN_PASS_DEF_CONVERTVECTORTOSCF
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37using vector::TransferReadOp;
38using vector::TransferWriteOp;
39
40namespace {
41
42/// Attribute name used for labeling transfer ops during progressive lowering.
43static const char kPassLabel[] = "__vector_to_scf_lowering__";
44
45/// Return true if this transfer op operates on a source tensor.
46static bool isTensorOp(VectorTransferOpInterface xferOp) {
47 if (isa<RankedTensorType>(xferOp.getShapedType())) {
48 if (isa<vector::TransferWriteOp>(xferOp)) {
49 // TransferWriteOps on tensors have a result.
50 assert(xferOp->getNumResults() > 0);
51 }
52 return true;
53 }
54 return false;
55}
56
57/// Patterns that inherit from this struct have access to
58/// VectorTransferToSCFOptions.
59template <typename OpTy>
60struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
61 explicit VectorToSCFPattern(MLIRContext *context,
62 VectorTransferToSCFOptions opt)
63 : OpRewritePattern<OpTy>(context), options(opt) {}
64
65 LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
66 PatternRewriter &rewriter) const {
67 if (isTensorOp(xferOp) && !options.lowerTensors) {
68 return rewriter.notifyMatchFailure(
69 xferOp, "lowering tensor transfers is disabled");
70 }
71 return success();
72 }
73
74 VectorTransferToSCFOptions options;
75};
76
77/// Given a vector transfer op, calculate which dimension of the `source`
78/// memref should be unpacked in the next application of TransferOpConversion.
79/// A return value of std::nullopt indicates a broadcast.
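///
/// E.g. (illustrative): for a transfer with permutation map
/// `(d0, d1, d2) -> (d2, d1)`, the first result is `d2`, so memref dimension 2
/// is unpacked next. For `(d0, d1) -> (0, d1)`, the first result is the
/// constant 0, i.e. a broadcast, and std::nullopt is returned.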
80template <typename OpTy>
81static std::optional<int64_t> unpackedDim(OpTy xferOp) {
82 // TODO: support 0-d corner case.
83 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
84 auto map = xferOp.getPermutationMap();
85 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
86 return expr.getPosition();
87 }
88 assert(xferOp.isBroadcastDim(0) &&
89 "Expected AffineDimExpr or AffineConstantExpr");
90 return std::nullopt;
91}
92
93/// Compute the permutation map for the new (N-1)-D vector transfer op. This
94/// map is identical to the current permutation map, but the first result is
95/// omitted.
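///
/// E.g. (illustrative): `(d0, d1, d2) -> (d1, d2)` becomes
/// `(d0, d1, d2) -> (d2)`.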
96template <typename OpTy>
97static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
98 // TODO: support 0-d corner case.
99 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
100 auto map = xferOp.getPermutationMap();
101 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
102 b.getContext());
103}
104
105/// Calculate the indices for the new vector transfer op.
106///
107/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
108/// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
109/// ^^^^^^
110/// `iv` is the iteration variable of the (new) surrounding loop.
111template <typename OpTy>
112static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
113 SmallVectorImpl<Value> &indices) {
114 typename OpTy::Adaptor adaptor(xferOp);
115 // Corresponding memref dim of the vector dim that is unpacked.
116 auto dim = unpackedDim(xferOp);
117 auto prevIndices = adaptor.getIndices();
118 indices.append(prevIndices.begin(), prevIndices.end());
119
120 Location loc = xferOp.getLoc();
121 bool isBroadcast = !dim.has_value();
122 if (!isBroadcast) {
123 AffineExpr d0, d1;
124 bindDims(xferOp.getContext(), d0, d1);
125 Value offset = adaptor.getIndices()[*dim];
126 indices[*dim] =
127 affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
128 }
129}
130
131static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
132 Value value) {
133 if (hasRetVal) {
134 assert(value && "Expected non-empty value");
135 scf::YieldOp::create(b, loc, value);
136 } else {
137 scf::YieldOp::create(b, loc);
138 }
139}
140
141/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
142/// is set to true. No such check is generated under the following circumstances:
143/// * xferOp does not have a mask.
144/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
145/// computed and attached to the new transfer op in the pattern.)
146/// * The to-be-unpacked dim of xferOp is a broadcast.
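///
/// E.g. (illustrative sketch): for a hypothetical 1D mask %mask : vector<16xi1>,
/// the generated check is roughly:
/// ```
/// %is_masked_in = vector.extract %mask[%iv] : i1 from vector<16xi1>
/// ```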
147template <typename OpTy>
148static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
149 if (!xferOp.getMask())
150 return Value();
151 if (xferOp.getMaskType().getRank() != 1)
152 return Value();
153 if (xferOp.isBroadcastDim(0))
154 return Value();
155
156 Location loc = xferOp.getLoc();
157 return vector::ExtractOp::create(b, loc, xferOp.getMask(), iv);
158}
159
160/// Helper function for TransferOpConversion and TransferOp1dConversion.
161/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
162/// specified dimension `dim` with the loop iteration variable `iv`.
163/// E.g., when unpacking dimension 0 from:
164/// ```
165/// %vec = vector.transfer_read %A[%a, %b] %cst
166/// : vector<5x4xf32>, memref<?x?xf32>
167/// ```
168/// An if check similar to this will be generated inside the loop:
169/// ```
170/// %d = memref.dim %A, %c0 : memref<?x?xf32>
171/// if (%a + iv < %d) {
172/// (in-bounds case)
173/// } else {
174/// (out-of-bounds case)
175/// }
176/// ```
177///
178/// If the transfer is 1D and has a mask, this function generates a more complex
179/// check that also accounts for potentially masked-out elements.
180///
181/// This function variant returns the value returned by `inBoundsCase` or
182/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
183/// `resultTypes`.
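///
/// E.g. (illustrative sketch): for a 1D transfer with a mask, the generated
/// condition combines the bounds check with the mask bit, roughly:
/// ```
/// %in_bounds = arith.cmpi sgt, %d, %idx : index
/// %is_masked_in = vector.extract %mask[%iv] : i1 from vector<16xi1>
/// %cond = arith.andi %in_bounds, %is_masked_in : i1
/// scf.if %cond { (in-bounds case) } else { (out-of-bounds case) }
/// ```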
184template <typename OpTy>
185static Value generateInBoundsCheck(
186 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
187 TypeRange resultTypes,
188 function_ref<Value(OpBuilder &, Location)> inBoundsCase,
189 function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
190 bool hasRetVal = !resultTypes.empty();
191 Value cond; // Condition to be built...
192
193 // Condition check 1: Access in-bounds?
194 bool isBroadcast = !dim; // No in-bounds check for broadcasts.
195 Location loc = xferOp.getLoc();
196 ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
197 if (!xferOp.isDimInBounds(0) && !isBroadcast) {
198 Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
199 AffineExpr d0, d1;
200 bindDims(xferOp.getContext(), d0, d1);
201 Value base = xferOp.getIndices()[*dim];
202 Value memrefIdx =
203 affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
204 cond = arith::CmpIOp::create(lb, arith::CmpIPredicate::sgt, memrefDim,
205 memrefIdx);
206 }
207
208 // Condition check 2: Masked in?
209 if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
210 if (cond)
211 cond = arith::AndIOp::create(lb, cond, maskCond);
212 else
213 cond = maskCond;
214 }
215
216 // If the condition is non-empty, generate an SCF::IfOp.
217 if (cond) {
218 auto check = scf::IfOp::create(
219 lb, cond,
220 /*thenBuilder=*/
221 [&](OpBuilder &b, Location loc) {
222 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
223 },
224 /*elseBuilder=*/
225 [&](OpBuilder &b, Location loc) {
226 if (outOfBoundsCase) {
227 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
228 } else {
229 scf::YieldOp::create(b, loc);
230 }
231 });
232
233 return hasRetVal ? check.getResult(0) : Value();
234 }
235
236 // Condition is empty, no need for an SCF::IfOp.
237 return inBoundsCase(b, loc);
238}
239
240/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
241/// a return value. Consequently, this function does not have a return value.
242template <typename OpTy>
243static void generateInBoundsCheck(
244 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
245 function_ref<void(OpBuilder &, Location)> inBoundsCase,
246 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
247 generateInBoundsCheck(
248 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
249 /*inBoundsCase=*/
250 [&](OpBuilder &b, Location loc) {
251 inBoundsCase(b, loc);
252 return Value();
253 },
254 /*outOfBoundsCase=*/
255 [&](OpBuilder &b, Location loc) {
256 if (outOfBoundsCase)
257 outOfBoundsCase(b, loc);
258 return Value();
259 });
260}
261
262/// Given an ArrayAttr, return a copy where the first element is dropped.
263static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
264 if (!attr)
265 return attr;
266 return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
267}
268
269/// Add the pass label to a vector transfer op if its rank is not the target
270/// rank.
271template <typename OpTy>
272static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
273 unsigned targetRank) {
274 if (newXferOp.getVectorType().getRank() > targetRank)
275 newXferOp->setAttr(kPassLabel, b.getUnitAttr());
276}
277
278namespace lowering_n_d {
279
280/// Helper data structure for data and mask buffers.
281struct BufferAllocs {
282 Value dataBuffer;
283 Value maskBuffer;
284};
285
286// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
287static Operation *getAutomaticAllocationScope(Operation *op) {
288 Operation *scope =
289 op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
290 assert(scope && "Expected op to be inside automatic allocation scope");
291 return scope;
292}
293
294/// Allocate temporary buffers for data (vector) and mask (if present).
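///
/// E.g. (illustrative sketch): for a transfer of vector<5x4xf32> with a
/// vector<5x4xi1> mask, this creates roughly:
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %mask_buf[] : memref<vector<5x4xi1>>
/// %mask_val = memref.load %mask_buf[] : memref<vector<5x4xi1>>
/// ```
/// The returned BufferAllocs holds %data_buf and the re-loaded %mask_val.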
295template <typename OpTy>
296static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
297 Location loc = xferOp.getLoc();
298 OpBuilder::InsertionGuard guard(b);
299 Operation *scope = getAutomaticAllocationScope(xferOp);
300 assert(scope->getNumRegions() == 1 &&
301 "AutomaticAllocationScope with >1 regions");
302 b.setInsertionPointToStart(&scope->getRegion(0).front());
303
304 BufferAllocs result;
305 auto bufferType = MemRefType::get({}, xferOp.getVectorType());
306 result.dataBuffer = memref::AllocaOp::create(b, loc, bufferType);
307
308 if (xferOp.getMask()) {
309 auto maskType = MemRefType::get({}, xferOp.getMask().getType());
310 auto maskBuffer = memref::AllocaOp::create(b, loc, maskType);
311 b.setInsertionPoint(xferOp);
312 memref::StoreOp::create(b, loc, xferOp.getMask(), maskBuffer);
313 result.maskBuffer =
314 memref::LoadOp::create(b, loc, maskBuffer, ValueRange());
315 }
316
317 return result;
318}
319
320/// Given a MemRefType with VectorType element type, unpack one dimension from
321/// the VectorType into the MemRefType.
322///
323/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
324static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
325 auto vectorType = dyn_cast<VectorType>(type.getElementType());
326 // Vectors with leading scalable dims are not supported.
327 // It may be possible to support these in future by using dynamic memref dims.
328 if (vectorType.getScalableDims().front())
329 return failure();
330 auto memrefShape = type.getShape();
331 SmallVector<int64_t, 8> newMemrefShape;
332 newMemrefShape.append(memrefShape.begin(), memrefShape.end());
333 newMemrefShape.push_back(vectorType.getDimSize(0));
334 return MemRefType::get(newMemrefShape,
335 VectorType::Builder(vectorType).dropDim(0));
336}
337
338/// Given a transfer op, find the memref from which the mask is loaded. This
339/// is similar to Strategy<TransferWriteOp>::getBuffer.
340template <typename OpTy>
341static Value getMaskBuffer(OpTy xferOp) {
342 assert(xferOp.getMask() && "Expected that transfer op has mask");
343 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
344 assert(loadOp && "Expected transfer op mask produced by LoadOp");
345 return loadOp.getMemRef();
346}
347
348/// Codegen strategy, depending on the operation.
349template <typename OpTy>
350struct Strategy;
351
352/// Codegen strategy for vector TransferReadOp.
353template <>
354struct Strategy<TransferReadOp> {
355 /// Find the StoreOp that is used for writing the current TransferReadOp's
356 /// result to the temporary buffer allocation.
357 static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
358 assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
359 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
360 assert(storeOp && "Expected TransferReadOp result used by StoreOp");
361 return storeOp;
362 }
363
364 /// Find the temporary buffer allocation. All labeled TransferReadOps are
365 /// used like this, where %buf is either the buffer allocation or a type cast
366 /// of the buffer allocation:
367 /// ```
368 /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
369 /// memref.store %vec, %buf[...] ...
370 /// ```
371 static Value getBuffer(TransferReadOp xferOp) {
372 return getStoreOp(xferOp).getMemRef();
373 }
374
375 /// Retrieve the indices of the current StoreOp that stores into the buffer.
376 static void getBufferIndices(TransferReadOp xferOp,
377 SmallVectorImpl<Value> &indices) {
378 auto storeOp = getStoreOp(xferOp);
379 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
380 indices.append(prevIndices.begin(), prevIndices.end());
381 }
382
383 /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
384 /// accesses on the to-be-unpacked dimension.
385 ///
386 /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
387 /// variable `iv`.
388 /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
389 ///
390 /// E.g.:
391 /// ```
392 /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
393 /// : memref<?x?x?xf32>, vector<4x3xf32>
394 /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
395 /// ```
396 /// Is rewritten to:
397 /// ```
398 /// %casted = vector.type_cast %buf
399 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
400 /// for %j = 0 to 4 {
401 /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
402 /// : memref<?x?x?xf32>, vector<3xf32>
403 /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
404 /// }
405 /// ```
406 ///
407 /// Note: The loop and type cast are generated in TransferOpConversion.
408 /// The original TransferReadOp and store op are deleted in `cleanup`.
409 /// Note: The `mask` operand is set in TransferOpConversion.
410 static TransferReadOp rewriteOp(OpBuilder &b,
411 VectorTransferToSCFOptions options,
412 TransferReadOp xferOp, Value buffer, Value iv,
413 ValueRange /*loopState*/) {
414 SmallVector<Value, 8> storeIndices;
415 getBufferIndices(xferOp, storeIndices);
416 storeIndices.push_back(iv);
417
418 SmallVector<Value, 8> xferIndices;
419 getXferIndices(b, xferOp, iv, xferIndices);
420
421 Location loc = xferOp.getLoc();
422 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
423 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
424 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
425 auto newXferOp = vector::TransferReadOp::create(
426 b, loc, vecType, xferOp.getBase(), xferIndices,
427 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
428 xferOp.getPadding(), Value(), inBoundsAttr);
429
430 maybeApplyPassLabel(b, newXferOp, options.targetRank);
431
432 memref::StoreOp::create(b, loc, newXferOp.getVector(), buffer,
433 storeIndices);
434 return newXferOp;
435 }
436
437 /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
438 /// padding value to the temporary buffer.
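/// E.g. (illustrative), for the vector<5x4x3xf32> example above, inside the
/// generated loop:
/// ```
/// %pad_vec = vector.broadcast %cst : f32 to vector<3xf32>
/// memref.store %pad_vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
/// ```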
439 static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
440 Value buffer, Value iv,
441 ValueRange /*loopState*/) {
442 SmallVector<Value, 8> storeIndices;
443 getBufferIndices(xferOp, storeIndices);
444 storeIndices.push_back(iv);
445
446 Location loc = xferOp.getLoc();
447 auto bufferType = dyn_cast<ShapedType>(buffer.getType());
448 auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
449 auto vec =
450 vector::BroadcastOp::create(b, loc, vecType, xferOp.getPadding());
451 memref::StoreOp::create(b, loc, vec, buffer, storeIndices);
452
453 return Value();
454 }
455
456 /// Cleanup after rewriting the op.
457 static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
458 scf::ForOp /*forOp*/) {
459 rewriter.eraseOp(getStoreOp(xferOp));
460 rewriter.eraseOp(xferOp);
461 }
462
463 /// Return the initial loop state for the generated scf.for loop.
464 static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
465};
466
467/// Codegen strategy for vector TransferWriteOp.
468template <>
469struct Strategy<TransferWriteOp> {
470 /// Find the temporary buffer allocation. All labeled TransferWriteOps are
471 /// used like this, where %buf is either the buffer allocation or a type cast
472 /// of the buffer allocation:
473 /// ```
474 /// %vec = memref.load %buf[...] ...
475 /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
476 /// ```
477 static Value getBuffer(TransferWriteOp xferOp) {
478 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
479 assert(loadOp && "Expected transfer op vector produced by LoadOp");
480 return loadOp.getMemRef();
481 }
482
483 /// Retrieve the indices of the current LoadOp that loads from the buffer.
484 static void getBufferIndices(TransferWriteOp xferOp,
485 SmallVectorImpl<Value> &indices) {
486 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
487 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
488 indices.append(prevIndices.begin(), prevIndices.end());
489 }
490
491 /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
492 /// accesses on the to-be-unpacked dimension.
493 ///
494 /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
495 /// using the loop iteration variable `iv`.
496 /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
497 /// to memory.
498 ///
499 /// Note: For more details, see comments on Strategy<TransferReadOp>.
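/// E.g. (illustrative sketch), mirroring the TransferReadOp example, inside
/// the generated loop:
/// ```
/// %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
/// vector.transfer_write %vec, %A[%a + %i, %b + %j, %c]
///     : vector<3xf32>, memref<?x?x?xf32>
/// ```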
500 static TransferWriteOp rewriteOp(OpBuilder &b,
501 VectorTransferToSCFOptions options,
502 TransferWriteOp xferOp, Value buffer,
503 Value iv, ValueRange loopState) {
504 SmallVector<Value, 8> loadIndices;
505 getBufferIndices(xferOp, loadIndices);
506 loadIndices.push_back(iv);
507
508 SmallVector<Value, 8> xferIndices;
509 getXferIndices(b, xferOp, iv, xferIndices);
510
511 Location loc = xferOp.getLoc();
512 auto vec = memref::LoadOp::create(b, loc, buffer, loadIndices);
513 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
514 auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
515 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
516 auto newXferOp = vector::TransferWriteOp::create(
517 b, loc, type, vec, source, xferIndices,
518 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
519 inBoundsAttr);
520
521 maybeApplyPassLabel(b, newXferOp, options.targetRank);
522
523 return newXferOp;
524 }
525
526 /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
527 static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
528 Value buffer, Value iv,
529 ValueRange loopState) {
530 return isTensorOp(xferOp) ? loopState[0] : Value();
531 }
532
533 /// Cleanup after rewriting the op.
534 static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
535 scf::ForOp forOp) {
536 if (isTensorOp(xferOp)) {
537 assert(forOp->getNumResults() == 1 && "Expected one for loop result");
538 rewriter.replaceOp(xferOp, forOp->getResult(0));
539 } else {
540 rewriter.eraseOp(xferOp);
541 }
542 }
543
544 /// Return the initial loop state for the generated scf.for loop.
545 static Value initialLoopState(TransferWriteOp xferOp) {
546 return isTensorOp(xferOp) ? xferOp.getBase() : Value();
547 }
548};
549
550template <typename OpTy>
551static LogicalResult checkPrepareXferOp(OpTy xferOp, PatternRewriter &rewriter,
552 VectorTransferToSCFOptions options) {
553 if (xferOp->hasAttr(kPassLabel))
554 return rewriter.notifyMatchFailure(
555 xferOp, "kPassLabel is present (vector-to-scf lowering in progress)");
556 if (xferOp.getVectorType().getRank() <= options.targetRank)
557 return rewriter.notifyMatchFailure(
558 xferOp, "xferOp vector rank <= transformation target rank");
559 if (xferOp.getVectorType().getScalableDims().front())
560 return rewriter.notifyMatchFailure(
561 xferOp, "Unpacking of the leading dimension into the memref is not yet "
562 "supported for scalable dims");
563 if (isTensorOp(xferOp) && !options.lowerTensors)
564 return rewriter.notifyMatchFailure(
565 xferOp, "Unpacking for tensors has been disabled.");
566 if (xferOp.getVectorType().getElementType() !=
567 xferOp.getShapedType().getElementType())
568 return rewriter.notifyMatchFailure(
569 xferOp, "Mismatching source and destination element types.");
570
571 return success();
572}
573
574/// Prepare a TransferReadOp for progressive lowering.
575///
576/// 1. Allocate a temporary buffer.
577/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
578/// 3. Store the result of the TransferReadOp into the temporary buffer.
579/// 4. Load the result from the temporary buffer and replace all uses of the
580/// original TransferReadOp with this load.
581///
582/// E.g.:
583/// ```
584/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
585/// : vector<5x4xf32>, memref<?x?x?xf32>
586/// ```
587/// is rewritten to:
588/// ```
589/// %0 = memref.alloca() : memref<vector<5x4xf32>>
590/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
591/// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
592/// memref.store %1, %0[] : memref<vector<5x4xf32>>
593/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
594/// ```
595///
596/// Note: A second temporary buffer may be allocated for the `mask` operand.
597struct PrepareTransferReadConversion
598 : public VectorToSCFPattern<TransferReadOp> {
599 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
600
601 LogicalResult matchAndRewrite(TransferReadOp xferOp,
602 PatternRewriter &rewriter) const override {
603 if (checkPrepareXferOp(xferOp, rewriter, options).failed())
604 return rewriter.notifyMatchFailure(
605 xferOp, "checkPrepareXferOp conditions not met!");
606
607 auto buffers = allocBuffers(rewriter, xferOp);
608 auto *newXfer = rewriter.clone(*xferOp.getOperation());
609 newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
610 if (xferOp.getMask()) {
611 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
612 buffers.maskBuffer);
613 }
614
615 Location loc = xferOp.getLoc();
616 memref::StoreOp::create(rewriter, loc, newXfer->getResult(0),
617 buffers.dataBuffer);
618 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
619
620 return success();
621 }
622};
623
624/// Prepare a TransferWriteOp for progressive lowering.
625///
626/// 1. Allocate a temporary buffer.
627/// 2. Store the vector into the buffer.
628/// 3. Load the vector from the buffer again.
629/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
630/// marking it eligible for progressive lowering via TransferOpConversion.
631///
632/// E.g.:
633/// ```
634/// vector.transfer_write %vec, %A[%a, %b, %c]
635/// : vector<5x4xf32>, memref<?x?x?xf32>
636/// ```
637/// is rewritten to:
638/// ```
639/// %0 = memref.alloca() : memref<vector<5x4xf32>>
640/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
641/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
642/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
643/// : vector<5x4xf32>, memref<?x?x?xf32>
644/// ```
645///
646/// Note: A second temporary buffer may be allocated for the `mask` operand.
647struct PrepareTransferWriteConversion
648 : public VectorToSCFPattern<TransferWriteOp> {
649 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
650
651 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
652 PatternRewriter &rewriter) const override {
653 if (checkPrepareXferOp(xferOp, rewriter, options).failed())
654 return rewriter.notifyMatchFailure(
655 xferOp, "checkPrepareXferOp conditions not met!");
656
657 Location loc = xferOp.getLoc();
658 auto buffers = allocBuffers(rewriter, xferOp);
659 memref::StoreOp::create(rewriter, loc, xferOp.getVector(),
660 buffers.dataBuffer);
661 auto loadedVec = memref::LoadOp::create(rewriter, loc, buffers.dataBuffer);
662 rewriter.modifyOpInPlace(xferOp, [&]() {
663 xferOp.getValueToStoreMutable().assign(loadedVec);
664 xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
665 });
666
667 if (xferOp.getMask()) {
668 rewriter.modifyOpInPlace(xferOp, [&]() {
669 xferOp.getMaskMutable().assign(buffers.maskBuffer);
670 });
671 }
672
673 return success();
674 }
675};
676
677/// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This allows
678/// printing both 1D scalable vectors and n-D fixed size vectors.
679///
680/// E.g.:
681/// ```
682/// vector.print %v : vector<[4]xi32>
683/// ```
684/// is rewritten to:
685/// ```
686/// %c0 = arith.constant 0 : index
687/// %c4 = arith.constant 4 : index
688/// %c1 = arith.constant 1 : index
689/// %vscale = vector.vscale
690/// %length = arith.muli %vscale, %c4 : index
691/// %lastIndex = arith.subi %length, %c1 : index
692/// vector.print punctuation <open>
693/// scf.for %i = %c0 to %length step %c1 {
694/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
695/// vector.print %el : i32 punctuation <no_punctuation>
696/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
697/// scf.if %notLastIndex {
698/// vector.print punctuation <comma>
699/// }
700/// }
701/// vector.print punctuation <close>
702/// vector.print
703/// ```
704struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
705 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
706 LogicalResult matchAndRewrite(vector::PrintOp printOp,
707 PatternRewriter &rewriter) const override {
708 if (!printOp.getSource())
709 return failure();
710
711 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
712 if (!vectorType)
713 return failure();
714
715 // Currently >= 2D scalable vectors are not supported.
716 // These can't be lowered to LLVM (as LLVM does not support scalable vectors
717 // of scalable vectors), and due to limitations of current ops can't be
718 // indexed with SSA values or flattened. This may change after
719 // https://reviews.llvm.org/D155034, though there still needs to be a path
720 // for lowering to LLVM.
721 if (vectorType.getRank() > 1 && vectorType.isScalable())
722 return failure();
723
724 auto loc = printOp.getLoc();
725 auto value = printOp.getSource();
726
727 if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
728 // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
729 // avoid issues extend them to a more standard size.
730 // https://github.com/llvm/llvm-project/issues/30613
731 auto width = intTy.getWidth();
732 auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
733 auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
734 intTy.getSignedness());
735 // arith can only take signless integers, so we must cast back and forth.
736 auto signlessSourceVectorType =
737 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
738 auto signlessTargetVectorType =
739 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
740 auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
741 value = vector::BitCastOp::create(rewriter, loc, signlessSourceVectorType,
742 value);
743 if (value.getType() != signlessTargetVectorType) {
744 if (width == 1 || intTy.isUnsigned())
745 value = arith::ExtUIOp::create(rewriter, loc,
746 signlessTargetVectorType, value);
747 else
748 value = arith::ExtSIOp::create(rewriter, loc,
749 signlessTargetVectorType, value);
750 }
751 value = vector::BitCastOp::create(rewriter, loc, targetVectorType, value);
752 vectorType = targetVectorType;
753 }
754
755 auto scalableDimensions = vectorType.getScalableDims();
756 auto shape = vectorType.getShape();
757 constexpr int64_t singletonShape[] = {1};
758 if (vectorType.getRank() == 0)
759 shape = singletonShape;
760
761 if (vectorType.getRank() != 1) {
762 // Flatten n-D vectors to 1D. This is done to allow indexing with a
763 // non-constant value.
764 int64_t flatLength = llvm::product_of(shape);
765 auto flatVectorType =
766 VectorType::get({flatLength}, vectorType.getElementType());
767 value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
768 }
769
770 vector::PrintOp firstClose;
771 SmallVector<Value, 8> loopIndices;
772 for (unsigned d = 0; d < shape.size(); d++) {
773 // Setup loop bounds and step.
774 Value lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
775 Value upperBound =
776 arith::ConstantIndexOp::create(rewriter, loc, shape[d]);
777 Value step = arith::ConstantIndexOp::create(rewriter, loc, 1);
778 if (!scalableDimensions.empty() && scalableDimensions[d]) {
779 auto vscale = vector::VectorScaleOp::create(rewriter, loc,
780 rewriter.getIndexType());
781 upperBound = arith::MulIOp::create(rewriter, loc, upperBound, vscale);
782 }
783 auto lastIndex = arith::SubIOp::create(rewriter, loc, upperBound, step);
784
785 // Create a loop to print the elements surrounded by parentheses.
786 vector::PrintOp::create(rewriter, loc, vector::PrintPunctuation::Open);
787 auto loop =
788 scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
789 auto printClose = vector::PrintOp::create(
790 rewriter, loc, vector::PrintPunctuation::Close);
791 if (!firstClose)
792 firstClose = printClose;
793
794 auto loopIdx = loop.getInductionVar();
795 loopIndices.push_back(loopIdx);
796
797 // Print a comma after all but the last element.
798 rewriter.setInsertionPointToStart(loop.getBody());
799 auto notLastIndex = arith::CmpIOp::create(
800 rewriter, loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
801 scf::IfOp::create(rewriter, loc, notLastIndex,
802 [&](OpBuilder &builder, Location loc) {
803 vector::PrintOp::create(
804 builder, loc, vector::PrintPunctuation::Comma);
805 scf::YieldOp::create(builder, loc);
806 });
807
808 rewriter.setInsertionPointToStart(loop.getBody());
809 }
810
811 // Compute the flattened index.
812 // Note: For vectors of rank > 1, this assumes the shape is not scalable.
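// E.g. (illustrative): for an original shape of 5x4, the flattened index is
// %i0 * 4 + %i1 * 1, accumulating strides from the innermost dimension
// outwards.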
813 Value flatIndex;
814 auto currentStride = 1;
815 for (int d = shape.size() - 1; d >= 0; d--) {
816 auto stride =
817 arith::ConstantIndexOp::create(rewriter, loc, currentStride);
818 auto index = arith::MulIOp::create(rewriter, loc, stride, loopIndices[d]);
819 if (flatIndex)
820 flatIndex = arith::AddIOp::create(rewriter, loc, flatIndex, index);
821 else
822 flatIndex = index;
823 currentStride *= shape[d];
824 }
825
826 // Print the scalar elements in the innermost loop.
827 auto element = vector::ExtractOp::create(rewriter, loc, value, flatIndex);
828 vector::PrintOp::create(rewriter, loc, element,
829 vector::PrintPunctuation::NoPunctuation);
830
831 rewriter.setInsertionPointAfter(firstClose);
832 vector::PrintOp::create(rewriter, loc, printOp.getPunctuation());
833 rewriter.eraseOp(printOp);
834 return success();
835 }
836
837 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
838 return IntegerType::get(intTy.getContext(), intTy.getWidth(),
839 IntegerType::Signless);
840 };
841};
842
843/// Progressive lowering of vector transfer ops: Unpack one dimension.
844///
845/// 1. Unpack one dimension from the current buffer type and cast the buffer
846/// to that new type. E.g.:
847/// ```
848/// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
849/// vector.transfer_write %vec ...
850/// ```
851/// The following cast is generated:
852/// ```
853/// %casted = vector.type_cast %0
854/// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
855/// ```
856/// 2. Generate a for loop and rewrite the transfer op according to the
857/// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
858/// out-of-bounds, generate an if-check and handle both cases separately.
859/// 3. Clean up according to the corresponding Strategy<OpTy>.
860///
861/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
862/// source (as opposed to a memref source), then each iteration of the generated
863/// scf.for loop yields the new tensor value. E.g.:
864/// ```
865/// %result = scf.for i = 0 to 5 {
866/// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
867/// %1 = vector.transfer_write %0, %source[...]
868/// : vector<4x3xf32>, tensor<5x4x3xf32>
869/// scf.yield %1 : tensor<5x4x3xf32>
870/// }
871/// ```
872template <typename OpTy>
873struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
874 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
875
876 void initialize() {
877 // This pattern recursively unpacks one dimension at a time. The recursion
878 // is bounded because the rank is strictly decreasing.
879 this->setHasBoundedRewriteRecursion();
880 }
881
882 static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
883 SmallVectorImpl<Value> &loadIndices,
884 Value iv) {
885 assert(xferOp.getMask() && "Expected transfer op to have mask");
886
887 // Add load indices from the previous iteration.
888 // The mask buffer depends on the permutation map, which makes determining
889 // the indices quite complex, which is why we need to "look back" at the
890 // previous iteration to find the right indices.
891 Value maskBuffer = getMaskBuffer(xferOp);
892 for (Operation *user : maskBuffer.getUsers()) {
893 // If there is no previous load op, then the indices are empty.
894 if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
895 Operation::operand_range prevIndices = loadOp.getIndices();
896 loadIndices.append(prevIndices.begin(), prevIndices.end());
897 break;
898 }
899 }
900
901 // In case of broadcast: Use same indices to load from memref
902 // as before.
903 if (!xferOp.isBroadcastDim(0))
904 loadIndices.push_back(iv);
905 }
906
907 LogicalResult matchAndRewrite(OpTy xferOp,
908 PatternRewriter &rewriter) const override {
909 if (!xferOp->hasAttr(kPassLabel))
910 return rewriter.notifyMatchFailure(
911 xferOp, "kPassLabel is absent (op not prepared for progressive lowering)");
912
913 // Find and cast data buffer. How the buffer can be found depends on OpTy.
914 ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
915 Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
916 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
917 FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
918 if (failed(castedDataType))
919 return rewriter.notifyMatchFailure(xferOp,
920 "Failed to unpack one vector dim.");
921
922 auto castedDataBuffer =
923 vector::TypeCastOp::create(locB, *castedDataType, dataBuffer);
924
925 // If the xferOp has a mask: Find and cast mask buffer.
926 Value castedMaskBuffer;
927 if (xferOp.getMask()) {
928 Value maskBuffer = getMaskBuffer(xferOp);
929 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
930 // Do not unpack a dimension of the mask, if:
931 // * To-be-unpacked transfer op dimension is a broadcast.
932 // * Mask is 1D, i.e., the mask cannot be further unpacked.
933 // (That means that all remaining dimensions of the transfer op must
934 // be broadcasted.)
935 castedMaskBuffer = maskBuffer;
936 } else {
937 // It's safe to assume the mask buffer can be unpacked if the data
938 // buffer was unpacked.
939 auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
940 MemRefType castedMaskType = *unpackOneDim(maskBufferType);
941 castedMaskBuffer =
942 vector::TypeCastOp::create(locB, castedMaskType, maskBuffer);
943 }
944 }
945
946 // Loop bounds and step.
947 auto lb = arith::ConstantIndexOp::create(locB, 0);
948 auto ub = arith::ConstantIndexOp::create(
949 locB, castedDataType->getDimSize(castedDataType->getRank() - 1));
950 auto step = arith::ConstantIndexOp::create(locB, 1);
951 // TransferWriteOps that operate on tensors return the modified tensor and
952 // require a loop state.
953 auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
954
955 // Generate for loop.
956 auto result = scf::ForOp::create(
957 locB, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
958 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
959 Type stateType = loopState.empty() ? Type() : loopState[0].getType();
960
961 auto result = generateInBoundsCheck(
962 b, xferOp, iv, unpackedDim(xferOp),
963 stateType ? TypeRange(stateType) : TypeRange(),
964 /*inBoundsCase=*/
965 [&](OpBuilder &b, Location loc) {
966 // Create new transfer op.
967 OpTy newXfer = Strategy<OpTy>::rewriteOp(
968 b, this->options, xferOp, castedDataBuffer, iv, loopState);
969
970 // If old transfer op has a mask: Set mask on new transfer op.
971 // Special case: If the mask of the old transfer op is 1D and
972 // the unpacked dim is not a broadcast, no mask is needed on
973 // the new transfer op.
974 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
975 xferOp.getMaskType().getRank() > 1)) {
976 OpBuilder::InsertionGuard guard(b);
977 b.setInsertionPoint(newXfer); // Insert load before newXfer.
978
979 SmallVector<Value, 8> loadIndices;
980 getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
981 loadIndices, iv);
982 auto mask = memref::LoadOp::create(b, loc, castedMaskBuffer,
983 loadIndices);
984 rewriter.modifyOpInPlace(newXfer, [&]() {
985 newXfer.getMaskMutable().assign(mask);
986 });
987 }
988
989 return loopState.empty() ? Value() : newXfer->getResult(0);
990 },
991 /*outOfBoundsCase=*/
992 [&](OpBuilder &b, Location /*loc*/) {
993 return Strategy<OpTy>::handleOutOfBoundsDim(
994 b, xferOp, castedDataBuffer, iv, loopState);
995 });
996
997 maybeYieldValue(b, loc, !loopState.empty(), result);
998 });
999
1000 Strategy<OpTy>::cleanup(rewriter, xferOp, result);
1001 return success();
1002 }
1003};
1004
1005/// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
1006/// and ConstantMaskOp.
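///
/// E.g. (illustrative):
/// ```
/// %m0 = vector.create_mask %a, %b : vector<4x[4]xi1>    // -> {%a, %b}
/// %m1 = vector.constant_mask [2, 2] : vector<4x[4]xi1>  // -> {2, 2 * vscale}
/// ```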
1007template <typename VscaleConstantBuilder>
1008static FailureOr<SmallVector<OpFoldResult>>
1009getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1010 if (!mask)
1011 return SmallVector<OpFoldResult>{};
1012 if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1013 return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1014 return OpFoldResult(dimSize);
1015 });
1016 }
1017 if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1018 int dimIdx = 0;
1019 VectorType maskType = constantMask.getVectorType();
1020 auto indexType = IndexType::get(mask.getContext());
1021 return llvm::map_to_vector(
1022 constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1023 // A scalable dim in a constant_mask means vscale x dimSize.
1024 if (maskType.getScalableDims()[dimIdx++])
1025 return OpFoldResult(createVscaleMultiple(dimSize));
1026 return OpFoldResult(IntegerAttr::get(indexType, dimSize));
1027 });
1028 }
1029 return failure();
1030}
1031
1032/// Scalable vector lowering of transfer_write(transpose). This lowering only
1033/// supports rank 2 (scalable) vectors, but can be used in conjunction with
1034/// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
1035/// unrolls until the first scalable dimension.
1036///
1037/// Example:
1038///
1039/// BEFORE:
1040/// ```mlir
1041/// %transpose = vector.transpose %vec, [1, 0]
1042/// : vector<4x[4]xf32> to vector<[4]x4xf32>
1043/// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
1044/// : vector<[4]x4xf32>, memref<?x?xf32>
1045/// ```
1046///
1047/// AFTER:
1048/// ```mlir
1049/// %c1 = arith.constant 1 : index
1050/// %c4 = arith.constant 4 : index
1051/// %c0 = arith.constant 0 : index
1052/// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
1053/// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
1054/// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
1055/// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
1056/// %vscale = vector.vscale
1057/// %c4_vscale = arith.muli %vscale, %c4 : index
1058/// scf.for %idx = %c0 to %c4_vscale step %c1 {
1059/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1060/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1061/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1062/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1063/// %slice_i = affine.apply #map(%idx)[%i]
1064/// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
1065/// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
1066/// : vector<4xf32>, memref<?x?xf32>
1067/// }
1068/// ```
1069struct ScalableTransposeTransferWriteConversion
1070 : VectorToSCFPattern<vector::TransferWriteOp> {
1071 using VectorToSCFPattern::VectorToSCFPattern;
1072
1073 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
1074 PatternRewriter &rewriter) const override {
1075 if (failed(checkLowerTensors(writeOp, rewriter)))
1076 return failure();
1077
1078 VectorType vectorType = writeOp.getVectorType();
1079
1080 // Note: By comparing the scalable dims to an ArrayRef of length two, this
1081 // implicitly checks that the rank is also two.
1082 ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
1083 if (scalableFlags != ArrayRef<bool>{true, false}) {
1084 return rewriter.notifyMatchFailure(
1085 writeOp, "expected vector of the form vector<[N]xMxty>");
1086 }
1087
1088 auto permutationMap = writeOp.getPermutationMap();
1089 if (!permutationMap.isIdentity()) {
1090 return rewriter.notifyMatchFailure(
1091 writeOp, "non-identity permutations are unsupported (lower first)");
1092 }
1093
1094 // Note: This pattern is only lowering the leading dimension (to a loop),
1095 // so we only check if the leading dimension is in bounds. The in-bounds
1096 // attribute for the trailing dimension will be propagated.
1097 if (!writeOp.isDimInBounds(0)) {
1098 return rewriter.notifyMatchFailure(
1099 writeOp, "out-of-bounds dims are unsupported (use masking)");
1100 }
1101
1102 Value vector = writeOp.getVector();
1103 auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
1104 if (!transposeOp ||
1105 transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
1106 return rewriter.notifyMatchFailure(writeOp, "source not transpose");
1107 }
1108
1109 auto loc = writeOp.getLoc();
1110 auto createVscaleMultiple =
1111 vector::makeVscaleConstantBuilder(rewriter, loc);
1112
1113 auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
1114 if (failed(maskDims)) {
1115 return rewriter.notifyMatchFailure(writeOp,
1116 "failed to resolve mask dims");
1117 }
1118
1119 int64_t fixedDimSize = vectorType.getDimSize(1);
1120 auto fixedDimOffsets = llvm::seq(fixedDimSize);
1121
1122 // Extract all slices from the source of the transpose.
1123 auto transposeSource = transposeOp.getVector();
1124 SmallVector<Value> transposeSourceSlices =
1125 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1126 return vector::ExtractOp::create(rewriter, loc, transposeSource, idx);
1127 });
1128
1129 // Loop bounds and step.
1130 auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
1131 auto ub =
1132 maskDims->empty()
1133 ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
1134 : vector::getAsValues(rewriter, loc, maskDims->front()).front();
1135 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
1136
1137 // Generate a new mask for the slice.
1138 VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
1139 Value sliceMask = nullptr;
1140 if (!maskDims->empty()) {
1141 sliceMask = vector::CreateMaskOp::create(
1142 rewriter, loc, sliceType.clone(rewriter.getI1Type()),
1143 ArrayRef<OpFoldResult>(*maskDims).drop_front());
1144 }
1145
1146 Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
1147 ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
1148 auto result = scf::ForOp::create(
1149 rewriter, loc, lb, ub, step, initLoopArgs,
1150 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
1151 // Indices for the new transfer op.
1152 SmallVector<Value, 8> xferIndices;
1153 getXferIndices(b, writeOp, iv, xferIndices);
1154
1155 // Extract a transposed slice from the source vector.
1156 SmallVector<Value> transposeElements =
1157 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1158 return vector::ExtractOp::create(
1159 b, loc, transposeSourceSlices[idx], iv);
1160 });
1161 auto sliceVec = vector::FromElementsOp::create(b, loc, sliceType,
1162 transposeElements);
1163
1164 // Create the transfer_write for the slice.
1165 Value dest =
1166 loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
1167 auto newWriteOp = vector::TransferWriteOp::create(
1168 b, loc, sliceVec, dest, xferIndices,
1169 ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
1170 if (sliceMask)
1171 newWriteOp.getMaskMutable().assign(sliceMask);
1172
1173 // Yield from the loop.
1174 scf::YieldOp::create(b, loc,
1175 loopIterArgs.empty() ? ValueRange{}
1176 : newWriteOp.getResult());
1177 });
1178
1179 if (isTensorOp(writeOp))
1180 rewriter.replaceOp(writeOp, result);
1181 else
1182 rewriter.eraseOp(writeOp);
1183
1184 return success();
1185 }
1186};
1187
1188} // namespace lowering_n_d
1189
1190namespace lowering_n_d_unrolled {
1191
1192/// If the original transfer op has a mask, compute the mask of the new transfer
1193/// op (for the current iteration `i`) and assign it.
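///
/// E.g. (illustrative): for a mask of type vector<5x4xi1> and iteration i = 2,
/// the new transfer op receives the mask
/// `vector.extract %mask[2] : vector<4xi1> from vector<5x4xi1>`.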
1194template <typename OpTy>
1195static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1196 int64_t i) {
1197 if (!xferOp.getMask())
1198 return;
1199
1200 if (xferOp.isBroadcastDim(0)) {
1201 // To-be-unpacked dimension is a broadcast, which does not have a
1202 // corresponding mask dimension. Mask attribute remains unchanged.
1203 newXferOp.getMaskMutable().assign(xferOp.getMask());
1204 return;
1205 }
1206
1207 if (xferOp.getMaskType().getRank() > 1) {
1208 // Unpack one dimension of the mask.
1209 OpBuilder::InsertionGuard guard(b);
1210 b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1211
1212 llvm::SmallVector<int64_t, 1> indices({i});
1213 Location loc = xferOp.getLoc();
1214 auto newMask = vector::ExtractOp::create(b, loc, xferOp.getMask(), indices);
1215 newXferOp.getMaskMutable().assign(newMask);
1216 }
1217
1218 // If we end up here: The mask of the old transfer op is 1D and the unpacked
1219 // dim is not a broadcast, so no mask is needed on the new transfer op.
1220 // `generateInBoundsCheck` will have evaluated the mask already.
1221}
1222
1223/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1224/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1225/// memref buffer is allocated and the SCF loop is fully unrolled.
1226///
1227///
1228/// E.g.:
1229/// ```
1230/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1231/// : memref<?x?x?xf32>, vector<5x4xf32>
1232/// ```
1233/// is rewritten to IR such as (simplified):
1234/// ```
1235/// %v_init = splat %padding : vector<5x4xf32>
1236/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1237/// : memref<?x?x?xf32>, vector<4xf32>
1238/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1239/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1240/// : memref<?x?x?xf32>, vector<4xf32>
1241/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1242/// ...
1243/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1244/// : memref<?x?x?xf32>, vector<4xf32>
1245/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1246/// ```
1247///
1248/// Note: As an optimization, if the result of the original TransferReadOp
1249/// was directly inserted into another vector, no new %v_init vector is created.
1250/// Instead, the new TransferReadOp results are inserted into that vector.
1251struct UnrollTransferReadConversion
1252 : public VectorToSCFPattern<TransferReadOp> {
1253 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1254
1255 void initialize() {
1256 // This pattern recursively unpacks one dimension at a time. The recursion
1257 // is bounded because the rank is strictly decreasing.
1258 setHasBoundedRewriteRecursion();
1259 }
1260
1261 /// Get or build the vector into which the newly created TransferReadOp
1262 /// results are inserted.
1263 Value buildResultVector(PatternRewriter &rewriter,
1264 TransferReadOp xferOp) const {
1265 if (auto insertOp = getInsertOp(xferOp))
1266 return insertOp.getDest();
1267 Location loc = xferOp.getLoc();
1268 return vector::BroadcastOp::create(rewriter, loc, xferOp.getVectorType(),
1269 xferOp.getPadding());
1270 }
1271
1272 /// If the result of the TransferReadOp has exactly one user, which is a
1273 /// vector::InsertOp, return that operation.
1274 vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1275 if (xferOp->hasOneUse()) {
1276 Operation *xferOpUser = *xferOp->getUsers().begin();
1277 if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1278 return insertOp;
1279 }
1280
1281 return vector::InsertOp();
1282 }
1283
1284 /// If the result of the TransferReadOp has exactly one user, which is a
1285 /// vector::InsertOp, return that operation's indices.
1286 void getInsertionIndices(TransferReadOp xferOp,
1287 SmallVectorImpl<OpFoldResult> &indices) const {
1288 if (auto insertOp = getInsertOp(xferOp)) {
1289 auto pos = insertOp.getMixedPosition();
1290 indices.append(pos.begin(), pos.end());
1291 }
1292 }
1293
1294 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1295 /// accesses, and broadcasts and transposes in permutation maps.
1296 LogicalResult matchAndRewrite(TransferReadOp xferOp,
1297 PatternRewriter &rewriter) const override {
1298 if (xferOp.getVectorType().getRank() <= options.targetRank)
1299 return rewriter.notifyMatchFailure(
1300 xferOp, "vector rank is less than or equal to target rank");
1301 if (failed(checkLowerTensors(xferOp, rewriter)))
1302 return failure();
1303 if (xferOp.getVectorType().getElementType() !=
1304 xferOp.getShapedType().getElementType())
1305 return rewriter.notifyMatchFailure(
1306 xferOp, "not yet supported: element type mismatch");
1307 auto xferVecType = xferOp.getVectorType();
1308 if (xferVecType.getScalableDims()[0]) {
1309 return rewriter.notifyMatchFailure(
1310 xferOp, "scalable dimensions cannot be unrolled at compile time");
1311 }
1312
1313 auto insertOp = getInsertOp(xferOp);
1314 auto vec = buildResultVector(rewriter, xferOp);
1315 auto vecType = dyn_cast<VectorType>(vec.getType());
1316
1317 VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
1318
1319 int64_t dimSize = xferVecType.getShape()[0];
1320
1321 // Generate fully unrolled loop of transfer ops.
1322 Location loc = xferOp.getLoc();
1323 for (int64_t i = 0; i < dimSize; ++i) {
1324 Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
1325
1326 // FIXME: Rename this lambda - it does much more than just
1327 // in-bounds-check generation.
1328 vec = generateInBoundsCheck(
1329 rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1330 /*inBoundsCase=*/
1331 [&](OpBuilder &b, Location loc) {
1332 // Indices for the new transfer op.
1333 SmallVector<Value, 8> xferIndices;
1334 getXferIndices(b, xferOp, iv, xferIndices);
1335
1336 // Indices for the new vector.insert op.
1337 SmallVector<OpFoldResult, 8> insertionIndices;
1338 getInsertionIndices(xferOp, insertionIndices);
1339 insertionIndices.push_back(rewriter.getIndexAttr(i));
1340
1341 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1342
1343 auto newXferOp = vector::TransferReadOp::create(
1344 b, loc, newXferVecType, xferOp.getBase(), xferIndices,
1345 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1346 xferOp.getPadding(), Value(), inBoundsAttr);
1347 maybeAssignMask(b, xferOp, newXferOp, i);
1348
1349 Value valToInsert = newXferOp.getResult();
1350 if (newXferVecType.getRank() == 0) {
1351 // vector.insert does not accept rank-0 as the non-indexed
1352 // argument. Extract the scalar before inserting.
1353 valToInsert = vector::ExtractOp::create(b, loc, valToInsert,
1354 SmallVector<int64_t>());
1355 }
1356 return vector::InsertOp::create(b, loc, valToInsert, vec,
1357 insertionIndices);
1358 },
1359 /*outOfBoundsCase=*/
1360 [&](OpBuilder &b, Location loc) {
1361 // Loop through original (unmodified) vector.
1362 return vec;
1363 });
1364 }
1365
1366 if (insertOp) {
1367 // Rewrite single user of the old TransferReadOp, which was an InsertOp.
1368 rewriter.replaceOp(insertOp, vec);
1369 rewriter.eraseOp(xferOp);
1370 } else {
1371 rewriter.replaceOp(xferOp, vec);
1372 }
1373
1374 return success();
1375 }
1376};
1377
1378/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1379/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1380/// memref buffer is allocated and the SCF loop is fully unrolled.
1381///
1382///
1383/// E.g.:
1384/// ```
1385/// vector.transfer_write %vec, %A[%a, %b, %c]
1386/// : vector<5x4xf32>, memref<?x?x?xf32>
1387/// ```
1388/// is rewritten to IR such as (simplified):
1389/// ```
1390/// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1391/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1392/// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1393/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1394/// ...
1395/// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1396/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1397/// ```
1398///
1399/// Note: As an optimization, if the vector of the original TransferWriteOp
1400/// was directly extracted from another vector via an ExtractOp `a`, extract
1401/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1402/// doing so, `a` may become dead, and the number of ExtractOps generated during
1403/// recursive application of this pattern will be minimal.
1404struct UnrollTransferWriteConversion
1405 : public VectorToSCFPattern<TransferWriteOp> {
1406 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1407
1408 void initialize() {
1409 // This pattern recursively unpacks one dimension at a time. The recursion
1410 // is bounded because the rank is strictly decreasing.
1411 setHasBoundedRewriteRecursion();
1412 }
1413
1414 /// Return the vector from which newly generated ExtractOps will extract.
1415 Value getDataVector(TransferWriteOp xferOp) const {
1416 if (auto extractOp = getExtractOp(xferOp))
1417 return extractOp.getSource();
1418 return xferOp.getVector();
1419 }
1420
1421 /// If the input of the given TransferWriteOp is an ExtractOp, return it.
1422 vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1423 if (auto *op = xferOp.getVector().getDefiningOp())
1424 return dyn_cast<vector::ExtractOp>(op);
1425 return vector::ExtractOp();
1426 }
1427
1428 /// If the input of the given TransferWriteOp is an ExtractOp, return its
1429 /// indices.
1430 void getExtractionIndices(TransferWriteOp xferOp,
1431 SmallVectorImpl<OpFoldResult> &indices) const {
1432 if (auto extractOp = getExtractOp(xferOp)) {
1433 auto pos = extractOp.getMixedPosition();
1434 indices.append(pos.begin(), pos.end());
1435 }
1436 }
1437
1438 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1439 /// accesses, and broadcasts and transposes in permutation maps.
1440 LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1441 PatternRewriter &rewriter) const override {
1442 VectorType inputVectorTy = xferOp.getVectorType();
1443
1444 if (inputVectorTy.getRank() <= options.targetRank)
1445 return failure();
1446
1447 if (failed(checkLowerTensors(xferOp, rewriter)))
1448 return failure();
1449 // Transfer ops that modify the element type are not supported atm.
1450 if (inputVectorTy.getElementType() !=
1451 xferOp.getShapedType().getElementType())
1452 return failure();
1453
1454 auto vec = getDataVector(xferOp);
1455 if (inputVectorTy.getScalableDims()[0]) {
1456 // Cannot unroll a scalable dimension at compile time.
1457 return failure();
1458 }
1459
1460 int64_t dimSize = inputVectorTy.getShape()[0];
1461 Value source = xferOp.getBase(); // memref or tensor to be written to.
1462 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1463
1464 // Generate fully unrolled loop of transfer ops.
1465 Location loc = xferOp.getLoc();
1466 for (int64_t i = 0; i < dimSize; ++i) {
1467 Value iv = arith::ConstantIndexOp::create(rewriter, loc, i);
1468
1469 auto updatedSource = generateInBoundsCheck(
1470 rewriter, xferOp, iv, unpackedDim(xferOp),
1471 isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1472 /*inBoundsCase=*/
1473 [&](OpBuilder &b, Location loc) {
1474 // Indices for the new transfer op.
1475 SmallVector<Value, 8> xferIndices;
1476 getXferIndices(b, xferOp, iv, xferIndices);
1477
1478 // Indices for the new vector.extract op.
1479 SmallVector<OpFoldResult, 8> extractionIndices;
1480 getExtractionIndices(xferOp, extractionIndices);
1481 extractionIndices.push_back(b.getI64IntegerAttr(i));
1482
1483 auto extracted =
1484 vector::ExtractOp::create(b, loc, vec, extractionIndices);
1485 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1486 Value xferVec;
1487 if (inputVectorTy.getRank() == 1) {
1488 // When target-rank=0, unrolling would cause the vector input
1489 // argument of `transfer_write` to become a scalar. We solve
1490 // this by broadcasting the scalar to a 0-D vector.
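// (Illustratively, an extracted f32 element becomes
// `vector.broadcast %e : f32 to vector<f32>`.)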
1491 xferVec = vector::BroadcastOp::create(
1492 b, loc, VectorType::get({}, extracted.getType()), extracted);
1493 } else {
1494 xferVec = extracted;
1495 }
1496 auto newXferOp = vector::TransferWriteOp::create(
1497 b, loc, sourceType, xferVec, source, xferIndices,
1498 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1499 inBoundsAttr);
1500
1501 maybeAssignMask(b, xferOp, newXferOp, i);
1502
1503 return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1504 },
1505 /*outOfBoundsCase=*/
1506 [&](OpBuilder &b, Location loc) {
1507 return isTensorOp(xferOp) ? source : Value();
1508 });
1509
1510 if (isTensorOp(xferOp))
1511 source = updatedSource;
1512 }
1513
1514 if (isTensorOp(xferOp))
1515 rewriter.replaceOp(xferOp, source);
1516 else
1517 rewriter.eraseOp(xferOp);
1518
1519 return success();
1520 }
1521};
1522
1523} // namespace lowering_n_d_unrolled
1524
1525namespace lowering_1_d {
1526
1527/// Compute the indices into the memref for the LoadOp/StoreOp generated as
1528/// part of TransferOp1dConversion. Return the memref dimension on which
1529/// the transfer is operating. A return value of std::nullopt indicates a
1530/// broadcast.
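///
/// E.g., for a transfer with indices [%a, %b] and permutation map
/// (d0, d1) -> (d0), iteration `iv` yields the memref indices [%a + iv, %b]
/// and the function returns dimension 0.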
1531template <typename OpTy>
1532static std::optional<int64_t>
1533get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1534 SmallVector<Value, 8> &memrefIndices) {
1535 auto indices = xferOp.getIndices();
1536 auto map = xferOp.getPermutationMap();
1537 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1538
1539 memrefIndices.append(indices.begin(), indices.end());
1540 assert(map.getNumResults() == 1 &&
1541 "Expected 1 permutation map result for 1D transfer");
1542 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1543 Location loc = xferOp.getLoc();
1544 auto dim = expr.getPosition();
1545 AffineExpr d0, d1;
1546 bindDims(xferOp.getContext(), d0, d1);
1547 Value offset = memrefIndices[dim];
1548 memrefIndices[dim] =
1549 affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1550 return dim;
1551 }
1552
1553 assert(xferOp.isBroadcastDim(0) &&
1554 "Expected AffineDimExpr or AffineConstantExpr");
1555 return std::nullopt;
1556}
1557
1558/// Codegen strategy for TransferOp1dConversion, depending on the
1559/// operation.
1560template <typename OpTy>
1561struct Strategy1d;
1562
1563/// Codegen strategy for TransferReadOp.
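///
/// For an in-bounds read, the loop body emitted here is approximately the
/// following pseudo-IR (illustrative shapes):
/// ```
/// %t = memref.load %A[%a + i, %b] : memref<?x?xf32>
/// %vec_next = vector.insert %t, %vec[i] : f32 into vector<9xf32>
/// scf.yield %vec_next : vector<9xf32>
/// ```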
1564template <>
1565struct Strategy1d<TransferReadOp> {
1566 static void generateForLoopBody(OpBuilder &b, Location loc,
1567 TransferReadOp xferOp, Value iv,
1568 ValueRange loopState) {
1569 SmallVector<Value, 8> indices;
1570 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1571 auto vec = loopState[0];
1572
1573 // In case of out-of-bounds access, leave `vec` as is (it was initialized
1574 // with the padding value).
1575 auto nextVec = generateInBoundsCheck(
1576 b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1577 /*inBoundsCase=*/
1578 [&](OpBuilder &b, Location loc) {
1579 Value val = memref::LoadOp::create(b, loc, xferOp.getBase(), indices);
1580 return vector::InsertOp::create(b, loc, val, vec, iv);
1581 },
1582 /*outOfBoundsCase=*/
1583 [&](OpBuilder & /*b*/, Location loc) { return vec; });
1584 scf::YieldOp::create(b, loc, nextVec);
1585 }
1586
1587 static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
1588 // Initialize vector with padding value.
1589 Location loc = xferOp.getLoc();
1590 return vector::BroadcastOp::create(b, loc, xferOp.getVectorType(),
1591 xferOp.getPadding());
1592 }
1593};
1594
1595/// Codegen strategy for TransferWriteOp.
1596template <>
1597struct Strategy1d<TransferWriteOp> {
1598 static void generateForLoopBody(OpBuilder &b, Location loc,
1599 TransferWriteOp xferOp, Value iv,
1600 ValueRange /*loopState*/) {
1601 SmallVector<Value, 8> indices;
1602 auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1603
1604 // Nothing to do in case of out-of-bounds access.
1605 generateInBoundsCheck(
1606 b, xferOp, iv, dim,
1607 /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1608 auto val = vector::ExtractOp::create(b, loc, xferOp.getVector(), iv);
1609 memref::StoreOp::create(b, loc, val, xferOp.getBase(), indices);
1610 });
1611 scf::YieldOp::create(b, loc);
1612 }
1613
1614 static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1615 return Value();
1616 }
1617};
1618
1619/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1620/// necessary in cases where a 1D vector transfer op cannot be lowered into
1621 /// vector loads/stores due to non-unit strides or broadcasts:
1622///
1623/// * Transfer dimension is not the last memref dimension
1624/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1625/// * Memref has a layout map with non-unit stride on the last dimension
1626///
1627/// This pattern generates IR as follows:
1628///
1629/// 1. Generate a for loop iterating over each vector element.
1630 /// 2. Inside the loop, generate a scalar load + vector.insert (for reads)
1631 /// or a vector.extract + scalar store (for writes), depending on OpTy.
1632///
1633/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1634/// can be generated instead of TransferOp1dConversion. Add such a pattern
1635/// to ConvertVectorToLLVM.
1636///
1637/// E.g.:
1638/// ```
1639/// vector.transfer_write %vec, %A[%a, %b]
1640/// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1641/// : vector<9xf32>, memref<?x?xf32>
1642/// ```
1643 /// is rewritten to approximately the following pseudo-IR:
1644/// ```
1645/// for i = 0 to 9 {
1646/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
1647/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1648/// }
1649/// ```
1650template <typename OpTy>
1651struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1652 using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1653
1654 LogicalResult matchAndRewrite(OpTy xferOp,
1655 PatternRewriter &rewriter) const override {
1656 // TODO: support 0-d corner case.
1657 if (xferOp.getTransferRank() == 0)
1658 return failure();
1659 auto map = xferOp.getPermutationMap();
1660 auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1661
1662 if (!memRefType)
1663 return failure();
1664 if (xferOp.getVectorType().getRank() != 1)
1665 return failure();
1666 if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
1667 return failure(); // Handled by ConvertVectorToLLVM
1668
1669 // Loop bounds, step, state...
1670 Location loc = xferOp.getLoc();
1671 auto vecType = xferOp.getVectorType();
1672 auto lb = arith::ConstantIndexOp::create(rewriter, loc, 0);
1673 Value ub =
1674 arith::ConstantIndexOp::create(rewriter, loc, vecType.getDimSize(0));
1675 if (vecType.isScalable()) {
1676 Value vscale =
1677 vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
1678 ub = arith::MulIOp::create(rewriter, loc, ub, vscale);
1679 }
1680 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
1681 auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1682
1683 // Generate for loop.
1684 rewriter.replaceOpWithNewOp<scf::ForOp>(
1685 xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1686 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1687 Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1688 });
1689
1690 return success();
1691 }
1692};
1693
1694} // namespace lowering_1_d
1695} // namespace
1696
1697 void mlir::populateVectorToSCFConversionPatterns(
1698 RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1699 if (options.unroll) {
1700 patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1701 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1702 patterns.getContext(), options);
1703 } else {
1704 patterns.add<lowering_n_d::PrepareTransferReadConversion,
1705 lowering_n_d::PrepareTransferWriteConversion,
1706 lowering_n_d::TransferOpConversion<TransferReadOp>,
1707 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1708 patterns.getContext(), options);
1709 }
1710 if (options.lowerScalable) {
1711 patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
1712 patterns.getContext(), options);
1713 }
1714 if (options.targetRank == 1) {
1715 patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1716 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1717 patterns.getContext(), options);
1718 }
1719 patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
1720 options);
1721}
1722
1723namespace {
1724
1725struct ConvertVectorToSCFPass
1726 : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1727 ConvertVectorToSCFPass() = default;
1728 ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1729 this->fullUnroll = options.unroll;
1730 this->targetRank = options.targetRank;
1731 this->lowerTensors = options.lowerTensors;
1732 this->lowerScalable = options.lowerScalable;
1733 }
1734
1735 void runOnOperation() override {
1736 VectorTransferToSCFOptions options;
1737 options.unroll = fullUnroll;
1738 options.targetRank = targetRank;
1739 options.lowerTensors = lowerTensors;
1740 options.lowerScalable = lowerScalable;
1741
1742 // Lower permutation maps first.
1743 RewritePatternSet lowerTransferPatterns(&getContext());
1744 vector::populateVectorTransferPermutationMapLoweringPatterns(
1745 lowerTransferPatterns);
1746 (void)applyPatternsGreedily(getOperation(),
1747 std::move(lowerTransferPatterns));
1748
1749 RewritePatternSet patterns(&getContext());
1750 populateVectorToSCFConversionPatterns(patterns, options);
1751 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
1752 }
1753};
1754
1755} // namespace
1756
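// With the standard pass registration (command-line option names assumed from
// the fields above), this lowering can typically be exercised with, e.g.:
//   mlir-opt --convert-vector-to-scf="full-unroll=true target-rank=1" input.mlir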
1757 std::unique_ptr<Pass>
1758 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1759 return std::make_unique<ConvertVectorToSCFPass>(options);
1760}