VectorDistribute.cpp
1//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/GPU/IR/GPUDialect.h"
12#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
17#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/Interfaces/SideEffectInterfaces.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <utility>
26
27using namespace mlir;
28using namespace mlir::vector;
29using namespace mlir::gpu;
30
31/// Currently the distribution map is implicit based on the vector shape. In the
32/// future it will be part of the op.
33/// Example:
34/// ```
35/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
36/// ...
37/// gpu.yield %3 : vector<32x16x64xf32>
38/// }
39/// ```
40/// Would have an implicit map of:
41/// `(d0, d1, d2) -> (d0, d2)`
42static AffineMap calculateImplicitMap(VectorType sequentialType,
43 VectorType distributedType) {
44 SmallVector<AffineExpr> perm;
45 perm.reserve(1);
46 // Check which dimensions of the sequential type are different than the
47 // dimensions of the distributed type to know the distributed dimensions. Then
48 // associate each distributed dimension to an ID in order.
49 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
51 perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
52 }
53 auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
54 distributedType.getContext());
55 return map;
56}
57
58/// Given a sequential and distributed vector type, returns the distributed
59/// dimension. This function expects that only a single dimension is
60/// distributed.
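/// For illustration (not from the original source): with a sequential type of
/// vector<16x32xf32> and a distributed type of vector<16x2xf32>, only
/// dimension 1 differs, so the distributed dimension is 1.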
61static int getDistributedDim(VectorType sequentialType,
62 VectorType distributedType) {
63 assert(sequentialType.getRank() == distributedType.getRank() &&
64 "sequential and distributed vector types must have the same rank");
65 int64_t distributedDim = -1;
66 for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
68 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
69 // support distributing multiple dimensions in the future.
70 assert(distributedDim == -1 && "found multiple distributed dims");
71 distributedDim = i;
72 }
73 }
74 return distributedDim;
75}
76
77namespace {
78
79/// Helper struct to create the load / store operations that permit transit
80/// through the parallel / sequential and the sequential / parallel boundaries
81/// when performing `rewriteWarpOpToScfFor`.
82///
83/// The vector distribution dimension is inferred from the vector types.
84struct DistributedLoadStoreHelper {
85 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
86 Value laneId, Value zero)
87 : sequentialVal(sequentialVal), distributedVal(distributedVal),
88 laneId(laneId), zero(zero) {
89 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
90 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
91 if (sequentialVectorType && distributedVectorType)
92 distributionMap =
93 calculateImplicitMap(sequentialVectorType, distributedVectorType);
94 }
95
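  /// Build the offset of lane `laneId` along the distributed dimension
  /// `index`. A hedged summary of the computation below (comment added for
  /// clarity, not in the original): the offset is `laneId * distributedSize`,
  /// e.g. with vector<32xf32> distributed as vector<1xf32>, lane 5 accesses
  /// offset 5.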
96 Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
97 int64_t distributedSize = distributedVectorType.getDimSize(index);
98 AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
99 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
100 ArrayRef<Value>{laneId});
101 }
102
103 /// Create a store during the process of distributing the
104/// `gpu.warp_execute_on_lane_0` op.
105 /// Vector distribution assumes the following convention regarding the
106 /// temporary buffers that are created to transition values. This **must**
107 /// be properly specified in the `options.warpAllocationFn`:
108 /// 1. scalars of type T transit through a memref<1xT>.
109 /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
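  /// For example (illustrative only; actual buffers come from
  /// `options.warpAllocationFn`): an f32 scalar transits through a
  /// memref<1xf32>, and a sequential vector<32xf32> transits through a
  /// memref<32xf32> into which each lane writes its distributed vector<1xf32>
  /// at offset `laneId`.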
110 Operation *buildStore(RewriterBase &b, Location loc, Value val,
111 Value buffer) {
112 assert((val == distributedVal || val == sequentialVal) &&
113 "Must store either the preregistered distributed or the "
114 "preregistered sequential value.");
115 // Scalar case can directly use memref.store.
116 if (!isa<VectorType>(val.getType()))
117 return memref::StoreOp::create(b, loc, val, buffer, zero);
118
119 // Vector case must use vector::TransferWriteOp which will later lower to
120 // vector.store or memref.store depending on further lowerings.
121 int64_t rank = sequentialVectorType.getRank();
122 SmallVector<Value> indices(rank, zero);
123 if (val == distributedVal) {
124 for (auto dimExpr : distributionMap.getResults()) {
125 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126 indices[index] = buildDistributedOffset(b, loc, index);
127 }
128 }
129 SmallVector<bool> inBounds(indices.size(), true);
130 return vector::TransferWriteOp::create(
131 b, loc, val, buffer, indices,
132 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
133 }
134
135 /// Create a load during the process of distributing the
136/// `gpu.warp_execute_on_lane_0` op.
137 /// Vector distribution assumes the following convention regarding the
138 /// temporary buffers that are created to transition values. This **must**
139 /// be properly specified in the `options.warpAllocationFn`:
140 /// 1. scalars of type T transit through a memref<1xT>.
141 /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
142 ///
143 /// When the requested type equals the sequential type, the load is not
144 /// distributed, matching the broadcast semantics of `gpu.warp_execute_on_lane_0`.
145 ///
146 /// Example:
147 ///
148 /// ```
149 /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
150 /// gpu.yield %cst : f32
151 /// }
152 /// // Both types are f32. The constant %cst is broadcasted to all lanes.
153 /// ```
154/// This behavior is described in more detail in the documentation of the op.
155 Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
156
157 // Scalar case can directly use memref.load.
158 if (!isa<VectorType>(type))
159 return memref::LoadOp::create(b, loc, buffer, zero);
160
161 // Other cases must be vector atm.
162 // Vector case must use vector::TransferReadOp which will later lower to
163 // vector.load or memref.load depending on further lowerings.
164 assert((type == distributedVectorType || type == sequentialVectorType) &&
165 "Must load either the preregistered distributed or the "
166 "preregistered sequential type.");
167 SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
168 if (type == distributedVectorType) {
169 for (auto dimExpr : distributionMap.getResults()) {
170 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171 indices[index] = buildDistributedOffset(b, loc, index);
172 }
173 }
174 SmallVector<bool> inBounds(indices.size(), true);
175 return vector::TransferReadOp::create(
176 b, loc, cast<VectorType>(type), buffer, indices,
177 /*padding=*/std::nullopt,
178 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
179 }
180
181 Value sequentialVal, distributedVal, laneId, zero;
182 VectorType sequentialVectorType, distributedVectorType;
183 AffineMap distributionMap;
184};
185
186} // namespace
187
188// Clones `op` into a new operation that takes `operands` and returns
189// `resultTypes`.
190static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
191                                              Location loc, Operation *op,
192 ArrayRef<Value> operands,
193 ArrayRef<Type> resultTypes) {
194 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
195 op->getAttrs());
196 return rewriter.create(res);
197}
198
199namespace {
200
201/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
202/// thread `laneId` executes the entirety of the computation.
203///
204/// After the transformation:
205/// - the IR within the scf.if op can be thought of as executing sequentially
206/// (from the point of view of threads along `laneId`).
207/// - the IR outside of the scf.if op can be thought of as executing in
208/// parallel (from the point of view of threads along `laneId`).
209///
210/// Values that need to transit through the parallel / sequential and the
211/// sequential / parallel boundaries do so via reads and writes to a temporary
212/// memory location.
213///
214/// The transformation proceeds in multiple steps:
215/// 1. Create the scf.if op.
216/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
217/// within the scf.if to transit the values captured from above.
218/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
219/// consistent within the scf.if.
220/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
221/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
222/// transit the values returned by the op.
223/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
224/// consistent after the scf.if.
225/// 7. Perform late cleanups.
226///
227/// All this assumes the vector distribution occurs along the most minor
228/// distributed vector dimension.
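///
/// Illustrative sketch (not from the original source; buffer allocation and
/// synchronization depend on `options.warpAllocationFn` and
/// `options.warpSyncronizationFn`):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32]
///     args(%v : vector<1xf32>) -> (vector<1xf32>) {
/// ^bb0(%arg: vector<32xf32>):
///   %y = "some_def"(%arg) : (vector<32xf32>) -> vector<32xf32>
///   gpu.yield %y : vector<32xf32>
/// }
/// ```
/// roughly becomes
/// ```
/// vector.transfer_write %v, %buf0[%offset] : vector<1xf32>, memref<32xf32>
/// gpu.barrier
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   %arg = vector.transfer_read %buf0[%c0], %cst : memref<32xf32>, vector<32xf32>
///   %y = "some_def"(%arg) : (vector<32xf32>) -> vector<32xf32>
///   vector.transfer_write %y, %buf1[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// gpu.barrier
/// %r = vector.transfer_read %buf1[%offset], %cst : memref<32xf32>, vector<1xf32>
/// ```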
229struct WarpOpToScfIfPattern : public WarpDistributionPattern {
230 WarpOpToScfIfPattern(MLIRContext *context,
231 const WarpExecuteOnLane0LoweringOptions &options,
232 PatternBenefit benefit = 1)
233 : WarpDistributionPattern(context, benefit), options(options) {}
234
235 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236 PatternRewriter &rewriter) const override {
237 assert(warpOp.getBodyRegion().hasOneBlock() &&
238 "expected WarpOp with single block");
239 Block *warpOpBody = &warpOp.getBodyRegion().front();
240 Location loc = warpOp.getLoc();
241
242 // Passed all checks. Start rewriting.
243 OpBuilder::InsertionGuard g(rewriter);
244 rewriter.setInsertionPoint(warpOp);
245
246 // Step 1: Create scf.if op.
247 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
248 Value isLane0 = arith::CmpIOp::create(
249 rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
250 auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
251 /*withElseRegion=*/false);
252 rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
253
254 // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
255 // reads within the scf.if to transit the values captured from above.
256 SmallVector<Value> bbArgReplacements;
257 for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
258 Value sequentialVal = warpOpBody->getArgument(it.index());
259 Value distributedVal = it.value();
260 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
261 warpOp.getLaneid(), c0);
262
263 // Create buffer before the ifOp.
264 rewriter.setInsertionPoint(ifOp);
265 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
266 sequentialVal.getType());
267 // Store distributed vector into buffer, before the ifOp.
268 helper.buildStore(rewriter, loc, distributedVal, buffer);
269 // Load sequential vector from buffer, inside the ifOp.
270 rewriter.setInsertionPointToStart(ifOp.thenBlock());
271 bbArgReplacements.push_back(
272 helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
273 }
274
275 // Step 3. Insert sync after all the stores and before all the loads.
276 if (!warpOp.getArgs().empty()) {
277 rewriter.setInsertionPoint(ifOp);
278 options.warpSyncronizationFn(loc, rewriter, warpOp);
279 }
280
281 // Step 4. Move body of warpOp to ifOp.
282 rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
283
284 // Step 5. Insert appropriate writes within scf.if and reads after the
285 // scf.if to transit the values returned by the op.
286 // TODO: at this point, we can reuse the shared memory from previous
287 // buffers.
288 SmallVector<Value> replacements;
289 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
290 Location yieldLoc = yieldOp.getLoc();
291 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
292 Value sequentialVal = it.value();
293 Value distributedVal = warpOp->getResult(it.index());
294 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
295 warpOp.getLaneid(), c0);
296
297 // Create buffer before the ifOp.
298 rewriter.setInsertionPoint(ifOp);
299 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
300 sequentialVal.getType());
301
302 // Store yielded value into buffer, inside the ifOp, before the
303 // terminator.
304 rewriter.setInsertionPoint(yieldOp);
305 helper.buildStore(rewriter, loc, sequentialVal, buffer);
306
307 // Load distributed value from buffer, after the warpOp.
308 rewriter.setInsertionPointAfter(ifOp);
309 // Result type and yielded value type are the same. This is a broadcast.
310 // E.g.:
311 // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
312 // gpu.yield %cst : f32
313 // }
314 // Both types are f32. The constant %cst is broadcasted to all lanes.
315 // This is described in more detail in the documentation of the op.
316 replacements.push_back(
317 helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
318 }
319
320 // Step 6. Insert sync after all the stores and before all the loads.
321 if (!yieldOp.getOperands().empty()) {
322 rewriter.setInsertionPointAfter(ifOp);
323 options.warpSyncronizationFn(loc, rewriter, warpOp);
324 }
325
326 // Step 7. Delete terminator and add empty scf.yield.
327 rewriter.eraseOp(yieldOp);
328 rewriter.setInsertionPointToEnd(ifOp.thenBlock());
329 scf::YieldOp::create(rewriter, yieldLoc);
330
331 // Compute replacements for WarpOp results.
332 rewriter.replaceOp(warpOp, replacements);
333
334 return success();
335 }
336
337private:
338 const WarpExecuteOnLane0LoweringOptions &options;
339};
340
341/// Return the distributed vector type based on the original type and the
342/// distribution map. The map is expected to have as many input dimensions as
343/// the original type rank and should be a projection whose results are the
344/// distributed dimensions. If the number of results is zero, there is no
345/// distribution (i.e. the original type is returned).
346/// Otherwise, the number of results should equal the number of distributed
347/// dimensions, which is currently limited to 1.
348/// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
349/// and a warp size of 16 distributes the second dimension (associated with
350/// d1) and returns vector<16x2x64>.
351static VectorType getDistributedType(VectorType originalType, AffineMap map,
352 int64_t warpSize) {
353 // If the map has zero results, return the original type.
354 if (map.getNumResults() == 0)
355 return originalType;
356 SmallVector<int64_t> targetShape(originalType.getShape());
357 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
358 unsigned position = map.getDimPosition(i);
359 if (targetShape[position] % warpSize != 0) {
360 if (warpSize % targetShape[position] != 0) {
361 return VectorType();
362 }
363 warpSize /= targetShape[position];
364 targetShape[position] = 1;
365 continue;
366 }
367 targetShape[position] = targetShape[position] / warpSize;
368 warpSize = 1;
369 break;
370 }
371 if (warpSize != 1) {
372 return VectorType();
373 }
374 VectorType targetType =
375 VectorType::get(targetShape, originalType.getElementType());
376 return targetType;
377}
378
379/// Given a warpOp that contains ops with regions, the corresponding op's
380/// "inner" region and the distributionMapFn, get all values used by the op's
381/// region that are defined within the warpOp, but outside the inner region.
382/// Return the set of values, their types and their distributed types.
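/// For illustration (hypothetical IR, not from the original source): if the
/// warp op body defines `%v : vector<32xf32>` and an `scf.for` nested in that
/// body uses `%v`, then `%v` escapes into the inner region; with a warp size
/// of 32 its type is vector<32xf32> and its distributed type is vector<1xf32>.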
383std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
384           SmallVector<Type>>
385getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
386 DistributionMapFn distributionMapFn) {
387 llvm::SmallSetVector<Value, 32> escapingValues;
388 SmallVector<Type> escapingValueTypes;
389 SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
390 if (innerRegion.empty())
391 return {std::move(escapingValues), std::move(escapingValueTypes),
392 std::move(escapingValueDistTypes)};
393 mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
394 Operation *parent = operand->get().getParentRegion()->getParentOp();
395 if (warpOp->isAncestor(parent)) {
396 if (!escapingValues.insert(operand->get()))
397 return;
398 Type distType = operand->get().getType();
399 if (auto vecType = dyn_cast<VectorType>(distType)) {
400 AffineMap map = distributionMapFn(operand->get());
401 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
402 }
403 escapingValueTypes.push_back(operand->get().getType());
404 escapingValueDistTypes.push_back(distType);
405 }
406 });
407 return {std::move(escapingValues), std::move(escapingValueTypes),
408 std::move(escapingValueDistTypes)};
409}
410
411/// Distribute transfer_write ops based on the affine map returned by
412/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract`
413/// elements will not be distributed (this limit should be less than the warp size).
414///
415/// Example:
416/// ```
417/// %0 = gpu.warp_execute_on_lane_0(%id){
418/// ...
419/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
420/// gpu.yield
421/// }
422/// ```
423/// To
424/// ```
425/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
426///   ...
427///   gpu.yield %v : vector<32xf32>
428/// }
429/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
430struct WarpOpTransferWrite : public WarpDistributionPattern {
431 WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
432 unsigned maxNumElementsToExtract, PatternBenefit b = 1)
433 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
434 maxNumElementsToExtract(maxNumElementsToExtract) {}
435
436 /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
437 /// are multiples of the distribution ratio are supported at the moment.
438 LogicalResult tryDistributeOp(RewriterBase &rewriter,
439 vector::TransferWriteOp writeOp,
440 WarpExecuteOnLane0Op warpOp) const {
441 VectorType writtenVectorType = writeOp.getVectorType();
442
443 // 1. This pattern does not handle 0-D writes; bail out and let the
444 //    extraction path (tryExtractOp) handle them.
445 if (writtenVectorType.getRank() == 0)
446 return failure();
447
448 // 2. Compute the distributed type.
449 AffineMap map = distributionMapFn(writeOp.getVector());
450 VectorType targetType =
451 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
452 if (!targetType)
453 return failure();
454
455 // 2.5 Compute the distributed type for the new mask;
456 VectorType maskType;
457 if (writeOp.getMask()) {
458 // TODO: Distribution of masked writes with non-trivial permutation maps
459 // requires the distribution of the mask to elementwise match the
460 // distribution of the permuted written vector. Currently the details
461 // of which lane is responsible for which element is captured strictly
462 // by shape information on the warp op, and thus requires materializing
463 // the permutation in IR.
464 if (!writeOp.getPermutationMap().isMinorIdentity())
465 return failure();
466 maskType =
467 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
468 }
469
470 // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
471 // the rest.
472 vector::TransferWriteOp newWriteOp =
473 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
474
475 // 4. Reindex the write using the distribution map.
476 auto newWarpOp =
477 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
478
479 // Delinearize the lane id based on the way threads are divided across the
480 // vector. To get the number of threads per vector dimension, divide the
481 // sequential size by the distributed size along each dim.
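    // For example (illustrative): a sequential vector<4x32xf32> distributed
    // to vector<2x2xf32> over 32 lanes yields delinearized id sizes [2, 16]:
    // 2 groups of lanes along dim 0 and 16 lanes along dim 1.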
482 rewriter.setInsertionPoint(newWriteOp);
483 SmallVector<OpFoldResult> delinearizedIdSizes;
484 for (auto [seqSize, distSize] :
485 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
486 assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
487 delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
488 }
489 SmallVector<Value> delinearized;
490 if (map.getNumResults() > 1) {
491 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
492 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
493 delinearizedIdSizes)
494 .getResults();
495 } else {
496 // If there is only one map result, we can elide the delinearization
497 // op and use the lane id directly.
498 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
499 }
500
501 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
502 Location loc = newWriteOp.getLoc();
503 SmallVector<Value> indices(newWriteOp.getIndices().begin(),
504 newWriteOp.getIndices().end());
505 for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
506 AffineExpr d0, d1;
507 bindDims(newWarpOp.getContext(), d0, d1);
508 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
509 if (!indexExpr)
510 continue;
511 unsigned indexPos = indexExpr.getPosition();
512 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
513 Value laneId = delinearized[vectorPos];
514 auto scale =
515 rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
516 indices[indexPos] = affine::makeComposedAffineApply(
517     rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
518 }
519 newWriteOp.getIndicesMutable().assign(indices);
520
521 return success();
522 }
523
524 /// Extract TransferWriteOps of vector<1x> into a separate warp op.
525 LogicalResult tryExtractOp(RewriterBase &rewriter,
526 vector::TransferWriteOp writeOp,
527 WarpExecuteOnLane0Op warpOp) const {
528 Location loc = writeOp.getLoc();
529 VectorType vecType = writeOp.getVectorType();
530
531 if (vecType.getNumElements() > maxNumElementsToExtract) {
532 return rewriter.notifyMatchFailure(
533 warpOp,
534 llvm::formatv(
535 "writes more elements ({0}) than allowed to extract ({1})",
536 vecType.getNumElements(), maxNumElementsToExtract));
537 }
538
539 // Do not process warp ops that contain only TransferWriteOps.
540 if (llvm::all_of(warpOp.getOps(),
541 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
542 return failure();
543
544 SmallVector<Value> yieldValues = {writeOp.getVector()};
545 SmallVector<Type> retTypes = {vecType};
546 SmallVector<size_t> newRetIndices;
547 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
548 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
549 rewriter.setInsertionPointAfter(newWarpOp);
550
551 // Create a second warp op that contains only writeOp.
552 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
553 newWarpOp.getLaneid(),
554 newWarpOp.getWarpSize());
555 Block &body = secondWarpOp.getBodyRegion().front();
556 rewriter.setInsertionPointToStart(&body);
557 auto newWriteOp =
558 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
559 newWriteOp.getValueToStoreMutable().assign(
560 newWarpOp.getResult(newRetIndices[0]));
561 rewriter.eraseOp(writeOp);
562 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
563 return success();
564 }
565
566 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
567 PatternRewriter &rewriter) const override {
568 gpu::YieldOp yield = warpOp.getTerminator();
569 Operation *lastNode = yield->getPrevNode();
570 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
571 if (!writeOp)
572 return failure();
573
574 Value maybeMask = writeOp.getMask();
575 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
576 return writeOp.getVector() == value ||
577 (maybeMask && maybeMask == value) ||
578 warpOp.isDefinedOutsideOfRegion(value);
579 }))
580 return failure();
581
582 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
583 return success();
584
585 // Masked writes not supported for extraction.
586 if (writeOp.getMask())
587 return failure();
588
589 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
590 return success();
591
592 return failure();
593 }
594
595private:
596 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
597 /// execute op with the proper return type. The new write op is updated to
598 /// write the result of the new warp execute op. The old `writeOp` is deleted.
599 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
600 WarpExecuteOnLane0Op warpOp,
601 vector::TransferWriteOp writeOp,
602 VectorType targetType,
603 VectorType maybeMaskType) const {
604 assert(writeOp->getParentOp() == warpOp &&
605 "write must be nested immediately under warp");
606 OpBuilder::InsertionGuard g(rewriter);
607 SmallVector<size_t> newRetIndices;
608 WarpExecuteOnLane0Op newWarpOp;
609 if (maybeMaskType) {
610 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
611     rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
612 TypeRange{targetType, maybeMaskType}, newRetIndices);
613 } else {
614 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
615     rewriter, warpOp, ValueRange{{writeOp.getVector()}},
616 TypeRange{targetType}, newRetIndices);
617 }
618 rewriter.setInsertionPointAfter(newWarpOp);
619 auto newWriteOp =
620 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
621 rewriter.eraseOp(writeOp);
622 newWriteOp.getValueToStoreMutable().assign(
623 newWarpOp.getResult(newRetIndices[0]));
624 if (maybeMaskType)
625 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
626 return newWriteOp;
627 }
628
629 DistributionMapFn distributionMapFn;
630 unsigned maxNumElementsToExtract = 1;
631};
632
633/// Sink out elementwise op feeding into a warp op yield.
634/// ```
635/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
636/// ...
637/// %3 = arith.addf %1, %2 : vector<32xf32>
638/// gpu.yield %3 : vector<32xf32>
639/// }
640/// ```
641/// To
642/// ```
643/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
644/// vector<1xf32>, vector<1xf32>) {
645/// ...
646/// %4 = arith.addf %2, %3 : vector<32xf32>
647/// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
648/// vector<32xf32>
649/// }
650/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
651struct WarpOpElementwise : public WarpDistributionPattern {
652 using Base::Base;
653 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
654 PatternRewriter &rewriter) const override {
655 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
656 return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
657 });
658 if (!yieldOperand)
659 return failure();
660
661 Operation *elementWise = yieldOperand->get().getDefiningOp();
662 unsigned operandIndex = yieldOperand->getOperandNumber();
663 Value distributedVal = warpOp.getResult(operandIndex);
664 SmallVector<Value> yieldValues;
665 SmallVector<Type> retTypes;
666 Location loc = warpOp.getLoc();
667 for (OpOperand &operand : elementWise->getOpOperands()) {
668 Type targetType;
669 if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
670 // If the result type is a vector, the operands must also be vectors.
671 auto operandType = cast<VectorType>(operand.get().getType());
672 targetType =
673 VectorType::get(vecType.getShape(), operandType.getElementType());
674 } else {
675 auto operandType = operand.get().getType();
676 assert(!isa<VectorType>(operandType) &&
677 "unexpected yield of vector from op with scalar result type");
678 targetType = operandType;
679 }
680 retTypes.push_back(targetType);
681 yieldValues.push_back(operand.get());
682 }
683 SmallVector<size_t> newRetIndices;
684 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
685 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
686 rewriter.setInsertionPointAfter(newWarpOp);
687 SmallVector<Value> newOperands(elementWise->getOperands().begin(),
688 elementWise->getOperands().end());
689 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
690 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
691 }
692 OpBuilder::InsertionGuard g(rewriter);
693 rewriter.setInsertionPointAfter(newWarpOp);
694 Operation *newOp = cloneOpWithOperandsAndTypes(
695 rewriter, loc, elementWise, newOperands,
696 {newWarpOp.getResult(operandIndex).getType()});
697 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
698 newOp->getResult(0));
699 return success();
700 }
701};
702
703/// Sink out splat constant op feeding into a warp op yield.
704/// ```
705/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
706/// ...
707/// %cst = arith.constant dense<2.0> : vector<32xf32>
708/// gpu.yield %cst : vector<32xf32>
709/// }
710/// ```
711/// To
712/// ```
713/// gpu.warp_execute_on_lane_0(%arg0) {
714/// ...
715/// }
716/// %0 = arith.constant dense<2.0> : vector<1xf32>
717struct WarpOpConstant : public WarpDistributionPattern {
718 using Base::Base;
719 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
720 PatternRewriter &rewriter) const override {
721 OpOperand *yieldOperand =
722 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
723 if (!yieldOperand)
724 return failure();
725 auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
726 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
727 if (!dense)
728 return failure();
729 // Notify the rewriter that the warp op is changing (see the comment on
730 // the WarpOpTransferRead pattern).
731 rewriter.startOpModification(warpOp);
732 unsigned operandIndex = yieldOperand->getOperandNumber();
733 Attribute scalarAttr = dense.getSplatValue<Attribute>();
734 auto newAttr = DenseElementsAttr::get(
735 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
736 Location loc = warpOp.getLoc();
737 rewriter.setInsertionPointAfter(warpOp);
738 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
739 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
740 rewriter.finalizeOpModification(warpOp);
741 return success();
742 }
743};
744
745/// Sink out step op feeding into a warp op yield.
746/// The vector.step op is treated similarly to arith.constant, except that
747/// its result represents the sequence [0, vec_size).
748/// Due to the vec_size == warp_size limitation,
749/// we can simply wrap the lane id into a vector (i.e., broadcast).
750/// Supporting vec_size != warp_size may involve preserving the step
751/// result and using additional arith ops (the exact details are TBD).
752/// ```
753/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
754/// ...
755/// %cst = vector.step : vector<32xindex>
756///   gpu.yield %cst : vector<32xindex>
757/// }
758/// ```
759/// To
760/// ```
761/// gpu.warp_execute_on_lane_0(%arg0) {
762/// ...
763/// }
764/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
765struct WarpOpStep final : public WarpDistributionPattern {
766 using Base::Base;
767 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
768 PatternRewriter &rewriter) const override {
769 OpOperand *yieldOperand =
770 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
771 if (!yieldOperand)
772 return failure();
773 const unsigned operandIdx = yieldOperand->getOperandNumber();
774 auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
775 VectorType resTy = stepOp.getResult().getType();
776 if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
777 return rewriter.notifyMatchFailure(
778 warpOp,
779 llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
780 resTy.getNumElements(), warpOp.getWarpSize()));
781 VectorType newVecTy =
782 cast<VectorType>(warpOp.getResult(operandIdx).getType());
783 rewriter.setInsertionPointAfter(warpOp);
784 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
785 newVecTy, warpOp.getLaneid());
786 rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
787 return success();
788 }
789};
790
791/// Sink out transfer_read op feeding into a warp op yield.
792/// ```
793/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
794/// ...
795///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
796///     vector<32xf32>
797/// gpu.yield %2 : vector<32xf32>
798/// }
799/// ```
800/// To
801/// ```
802/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
803/// vector<1xf32>, vector<1xf32>) {
804/// ...
805///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32>
806///   gpu.yield %2 : vector<32xf32>
807/// }
808/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
809struct WarpOpTransferRead : public WarpDistributionPattern {
810 using Base::Base;
811 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
812 PatternRewriter &rewriter) const override {
813 // Try to find a distributable yielded read. Note that this pattern can
814 // still fail at the end after distribution, in which case this might have
815 // missed another distributable read.
816 OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
817 // Don't duplicate transfer_read ops when distributing.
818 return isa<vector::TransferReadOp>(op) && op->hasOneUse();
819 });
820 if (!operand)
821 return rewriter.notifyMatchFailure(
822 warpOp, "warp result is not a vector.transfer_read op");
823 auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
824
825 // Source must be defined outside of the region.
826 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
827 return rewriter.notifyMatchFailure(
828 read, "source must be defined outside of the region");
829
830 unsigned operandIndex = operand->getOperandNumber();
831 Value distributedVal = warpOp.getResult(operandIndex);
832
833 SmallVector<Value, 4> indices(read.getIndices().begin(),
834 read.getIndices().end());
835 auto sequentialType = cast<VectorType>(read.getResult().getType());
836 auto distributedType = cast<VectorType>(distributedVal.getType());
837 AffineMap map = calculateImplicitMap(sequentialType, distributedType);
838 AffineMap indexMap = map.compose(read.getPermutationMap());
839
840 // Try to delinearize the lane ID to match the rank expected for
841 // distribution.
842 SmallVector<Value> delinearizedIds;
843 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
844 distributedType.getShape(), warpOp.getWarpSize(),
845 warpOp.getLaneid(), delinearizedIds)) {
846 return rewriter.notifyMatchFailure(
847 read, "cannot delinearize lane ID for distribution");
848 }
849 assert(!delinearizedIds.empty() || map.getNumResults() == 0);
850
851 // Distribute indices and the mask (if present).
852 OpBuilder::InsertionGuard g(rewriter);
853 SmallVector<Value> additionalResults(indices.begin(), indices.end());
854 SmallVector<Type> additionalResultTypes(indices.size(),
855 rewriter.getIndexType());
856 additionalResults.push_back(read.getPadding());
857 additionalResultTypes.push_back(read.getPadding().getType());
858
859 bool hasMask = false;
860 if (read.getMask()) {
861 hasMask = true;
862 // TODO: Distribution of masked reads with non-trivial permutation maps
863 // requires the distribution of the mask to elementwise match the
864 // distribution of the permuted read vector. Currently the details
865 // of which lane is responsible for which element is captured strictly
866 // by shape information on the warp op, and thus requires materializing
867 // the permutation in IR.
868 if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
869 return rewriter.notifyMatchFailure(
870 read, "non-trivial permutation maps not supported");
871 VectorType maskType =
872 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
873 additionalResults.push_back(read.getMask());
874 additionalResultTypes.push_back(maskType);
875 }
876
877 SmallVector<size_t> newRetIndices;
878 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
879 rewriter, warpOp, additionalResults, additionalResultTypes,
880 newRetIndices);
881 distributedVal = newWarpOp.getResult(operandIndex);
882
883 // Distributed indices were appended first.
884 SmallVector<Value> newIndices;
885 for (int64_t i = 0, e = indices.size(); i < e; ++i)
886 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
887
888 rewriter.setInsertionPointAfter(newWarpOp);
889 for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
890 AffineExpr d0, d1;
891 bindDims(read.getContext(), d0, d1);
892 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
893 if (!indexExpr)
894 continue;
895 unsigned indexPos = indexExpr.getPosition();
896 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
897 int64_t scale = distributedType.getDimSize(vectorPos);
898 newIndices[indexPos] = affine::makeComposedAffineApply(
899 rewriter, read.getLoc(), d0 + scale * d1,
900 {newIndices[indexPos], delinearizedIds[vectorPos]});
901 }
902
903 // Distributed padding value was appended right after the indices.
904 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
905 // Distributed mask value was added at the end (if the op has a mask).
906 Value newMask =
907 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
908 : Value();
909 auto newRead = vector::TransferReadOp::create(
910 rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
911 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
912 read.getInBoundsAttr());
913
914 rewriter.replaceAllUsesWith(distributedVal, newRead);
915 return success();
916 }
917};
918
919/// Remove any result that has no use along with the matching yieldOp operand.
920// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
921struct WarpOpDeadResult : public WarpDistributionPattern {
922 using Base::Base;
923 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
924 PatternRewriter &rewriter) const override {
925 SmallVector<Type> newResultTypes;
926 newResultTypes.reserve(warpOp->getNumResults());
927 SmallVector<Value> newYieldValues;
928 newYieldValues.reserve(warpOp->getNumResults());
929 DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
930 DenseMap<OpResult, int64_t> dedupResultPositionMap;
931 gpu::YieldOp yield = warpOp.getTerminator();
932
933 // Some values may be yielded multiple times and correspond to multiple
934 // results. Deduplicating occurs by taking each result with its matching
935 // yielded value, and:
936 // 1. recording the unique first position at which the value with uses is
937 // yielded.
938 // 2. recording for the result, the first position at which the dedup'ed
939 // value is yielded.
940 // 3. skipping from the new result types / new yielded values any result
941 // that has no use or whose yielded value has already been seen.
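    // For example (illustrative): if %a is yielded as both result #0 and
    // result #2 and only result #2 has uses, the new warp op yields %a once
    // and result #2 maps to that single new position.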
942 for (OpResult result : warpOp.getResults()) {
943 if (result.use_empty())
944 continue;
945 Value yieldOperand = yield.getOperand(result.getResultNumber());
946 auto it = dedupYieldOperandPositionMap.insert(
947 std::make_pair(yieldOperand, newResultTypes.size()));
948 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
949 if (!it.second)
950 continue;
951 newResultTypes.push_back(result.getType());
952 newYieldValues.push_back(yieldOperand);
953 }
954 // No modification, exit early.
955 if (yield.getNumOperands() == newYieldValues.size())
956 return failure();
957 // Move the body of the old warpOp to a new warpOp.
958 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
959 rewriter, warpOp, newYieldValues, newResultTypes);
960
961 // Simplify the new warp op after dropping dead results.
962 newWarpOp.getBody()->walk([&](Operation *op) {
963 if (isOpTriviallyDead(op))
964 rewriter.eraseOp(op);
965 });
966
967 // Replace results of the old warpOp by the new, deduplicated results.
968 SmallVector<Value> newValues;
969 newValues.reserve(warpOp->getNumResults());
970 for (OpResult result : warpOp.getResults()) {
971 if (result.use_empty())
972 newValues.push_back(Value());
973 else
974 newValues.push_back(
975 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
976 }
977 rewriter.replaceOp(warpOp, newValues);
978 return success();
979 }
980};
981
982// If an operand is directly yielded out of the region we can forward it
983// directly and it doesn't need to go through the region.
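// For illustration (hypothetical IR, not from the original source):
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//     args(%v : vector<4xf32>) -> (vector<4xf32>) {
// ^bb0(%arg: vector<128xf32>):
//   gpu.yield %arg : vector<128xf32>
// }
// ```
// The yielded value is the block argument matching %v, so %r can be replaced
// directly by %v.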
984struct WarpOpForwardOperand : public WarpDistributionPattern {
985 using Base::Base;
986 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
987 PatternRewriter &rewriter) const override {
988 gpu::YieldOp yield = warpOp.getTerminator();
989 Value valForwarded;
990 unsigned resultIndex;
991 for (OpOperand &operand : yield->getOpOperands()) {
992 Value result = warpOp.getResult(operand.getOperandNumber());
993 if (result.use_empty())
994 continue;
995
996 // Assume all the values coming from above are uniform.
997 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
998 if (result.getType() != operand.get().getType())
999 continue;
1000 valForwarded = operand.get();
1001 resultIndex = operand.getOperandNumber();
1002 break;
1003 }
1004 auto arg = dyn_cast<BlockArgument>(operand.get());
1005 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1006 continue;
1007 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1008 if (result.getType() != warpOperand.getType())
1009 continue;
1010 valForwarded = warpOperand;
1011 resultIndex = operand.getOperandNumber();
1012 break;
1013 }
1014 if (!valForwarded)
1015 return failure();
1016 // Notify the rewriter that the warp op is changing (see the comment on
1017 // the WarpOpTransferRead pattern).
1018 rewriter.startOpModification(warpOp);
1019 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1020 rewriter.finalizeOpModification(warpOp);
1021 return success();
1022 }
1023};
1024
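/// Sink out a broadcast feeding into a warp op yield: the broadcast source is
/// additionally yielded from the warp op and the broadcast is re-created
/// after it with the distributed result type.
/// Illustrative example (not from the original source):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %b = vector.broadcast %s : f32 to vector<32xf32>
///   gpu.yield %b : vector<32xf32>
/// }
/// ```
/// becomes, after also yielding %s : f32 as a new result %r#1,
/// ```
/// %1 = vector.broadcast %r#1 : f32 to vector<1xf32>
/// ```
/// with uses of %0 replaced by %1 (the now-dead original result is cleaned up
/// by WarpOpDeadResult).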
1025struct WarpOpBroadcast : public WarpDistributionPattern {
1026 using Base::Base;
1027 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1028 PatternRewriter &rewriter) const override {
1029 OpOperand *operand =
1030 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1031 if (!operand)
1032 return failure();
1033 unsigned int operandNumber = operand->getOperandNumber();
1034 auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1035 Location loc = broadcastOp.getLoc();
1036 auto destVecType =
1037 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1038 Value broadcastSrc = broadcastOp.getSource();
1039 Type broadcastSrcType = broadcastSrc.getType();
1040
1041 // Check that the broadcast actually spans a set of values uniformly across
1042 // all threads. In other words, check that each thread can reconstruct
1043 // their own broadcast.
1044 // For that we simply check that the broadcast we want to build makes sense.
1045 if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1046 vector::BroadcastableToResult::Success)
1047 return failure();
1048 SmallVector<size_t> newRetIndices;
1049 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1050 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1051 rewriter.setInsertionPointAfter(newWarpOp);
1052 Value broadcasted = vector::BroadcastOp::create(
1053 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1054 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1055 broadcasted);
1056 return success();
1057 }
1058};
1059
1060/// Pattern to move shape cast out of the warp op. shape cast is basically a
1061/// no-op for warp distribution; we need to handle the shape though.
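/// Illustrative example (not from the original source): with a warp size of
/// 32, a yielded vector.shape_cast from vector<32x1xf32> to vector<32xf32>
/// (distributed result vector<1xf32>) is re-created after the warp op as a
/// shape_cast from the distributed source vector<1x1xf32> to vector<1xf32>.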
1062struct WarpOpShapeCast : public WarpDistributionPattern {
1063 using Base::Base;
1064 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1065 PatternRewriter &rewriter) const override {
1066 OpOperand *operand =
1067 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1068 if (!operand)
1069 return failure();
1070
1071 auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1072
1073 unsigned int operandNumber = operand->getOperandNumber();
1074 auto castDistributedType =
1075 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1076 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1077 VectorType castResultType = castDistributedType;
1078
1079 // We expect the distributed type to have a smaller rank than the original
1080 // type. Prepend with size-one dimensions to make them the same.
1081 unsigned castDistributedRank = castDistributedType.getRank();
1082 unsigned castOriginalRank = castOriginalType.getRank();
1083 if (castDistributedRank < castOriginalRank) {
1084 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1085 llvm::append_range(shape, castDistributedType.getShape());
1086 castDistributedType =
1087 VectorType::get(shape, castDistributedType.getElementType());
1088 }
1089
1090 SmallVector<size_t> newRetIndices;
1091 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1092 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1093 newRetIndices);
1094 rewriter.setInsertionPointAfter(newWarpOp);
1095 Value newCast = vector::ShapeCastOp::create(
1096 rewriter, oldCastOp.getLoc(), castResultType,
1097 newWarpOp->getResult(newRetIndices[0]));
1098 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1099 return success();
1100 }
1101};
1102
1103/// Sink out vector.create_mask op feeding into a warp op yield.
1104/// ```
1105/// %0 = ...
1106/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1107/// ...
1108/// %mask = vector.create_mask %0 : vector<32xi1>
1109/// gpu.yield %mask : vector<32xi1>
1110/// }
1111/// ```
1112/// To
1113/// ```
1114/// %0 = ...
1115/// gpu.warp_execute_on_lane_0(%arg0) {
1116/// ...
1117/// }
1118/// %cmp = arith.cmpi ult, %laneid, %0
1119/// %ub = arith.select %cmp, %c1, %c0
1120/// %1 = vector.create_mask %ub : vector<1xi1>
1121struct WarpOpCreateMask : public WarpDistributionPattern {
1122 using Base::Base;
1123 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1124 PatternRewriter &rewriter) const override {
1125 OpOperand *yieldOperand =
1126 getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1127 if (!yieldOperand)
1128 return failure();
1129
1130 auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1131
1132 // Early exit if any values needed for calculating the new mask indices
1133 // are defined inside the warp op.
1134 if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1135 return warpOp.isDefinedOutsideOfRegion(value);
1136 }))
1137 return failure();
1138
1139 Location loc = mask.getLoc();
1140 unsigned operandIndex = yieldOperand->getOperandNumber();
1141
1142 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1143 VectorType seqType = mask.getVectorType();
1144 ArrayRef<int64_t> seqShape = seqType.getShape();
1145 ArrayRef<int64_t> distShape = distType.getShape();
1146
1147 rewriter.setInsertionPointAfter(warpOp);
1148
1149 // Delinearize the lane ID for constructing the distributed mask sizes.
1150 SmallVector<Value> delinearizedIds;
1151 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1152 warpOp.getWarpSize(), warpOp.getLaneid(),
1153 delinearizedIds))
1154 return rewriter.notifyMatchFailure(
1155 mask, "cannot delinearize lane ID for distribution");
1156 assert(!delinearizedIds.empty());
1157
1158 // Notify the rewriter that the warp op is changing (see the comment on
1159 // the WarpOpTransferRead pattern).
1160 rewriter.startOpModification(warpOp);
1161
1162 AffineExpr s0, s1;
1163 bindSymbols(rewriter.getContext(), s0, s1);
1164 SmallVector<Value> newOperands;
1165 for (int i = 0, e = distShape.size(); i < e; ++i) {
1166 // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1167 // find the distance from the largest mask index owned by this lane to the
1168 // original mask size. `vector.create_mask` implicitly clamps mask
1169 // operands to the range [0, mask_vector_size[i]], or in other words, the
1170 // mask sizes are always in the range [0, mask_vector_size[i]].
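      // For example (illustrative): for a sequential vector<32xi1> mask
      // distributed to vector<1xi1> over 32 lanes, lane k gets the size
      // `bound - k`, which create_mask clamps into [0, 1]; lanes with
      // k < bound produce a single set bit, the rest an empty mask.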
1171 Value maskDimIdx = affine::makeComposedAffineApply(
1172 rewriter, loc, s1 - s0 * distShape[i],
1173 {delinearizedIds[i], mask.getOperand(i)});
1174 newOperands.push_back(maskDimIdx);
1175 }
1176
1177 auto newMask =
1178 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
1179 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1180 rewriter.finalizeOpModification(warpOp);
1181 return success();
1182 }
1183};
1184
1185/// Sink out insert_strided_slice op feeding into a warp op yield.
1186/// ```
1187/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1188/// ...
1189/// %src = ... : vector<4x32xf32>
1190/// %dest = ... : vector<8x32xf32>
1191/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1192/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1193/// gpu.yield %insert : vector<8x32xf32>
1194/// }
1195/// ```
1196/// To
1197/// ```
1198/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1199/// vector<8x1xf32>) {
1200/// ...
1201/// %src = ... : vector<4x32xf32>
1202/// %dest = ... : vector<8x32xf32>
1203///   gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
1204/// }
1205/// %insert = vector.insert_strided_slice %0#0, %0#1,
1206/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1207/// ```
1208/// NOTE: Current support assumes that both src and dest vectors are distributed
1209/// to lanes and sinking the insert op does not require any cross lane
1210/// communication.
1211struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
1212 using Base::Base;
1213 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1214 PatternRewriter &rewriter) const override {
1215 OpOperand *operand =
1216 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1217 if (!operand)
1218 return failure();
1219 unsigned int operandNumber = operand->getOperandNumber();
1220 auto insertOp =
1221 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1222 auto distributedType =
1223 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1224 // Distributed type must be 2D or higher.
1225 // TODO: Support 1D distributed types.
1226 if (distributedType.getRank() < 2)
1227 return rewriter.notifyMatchFailure(
1228 insertOp, "result vector type must be 2D or higher");
1229 // Find the distributed dimension of the dest vector. There should be
1230 // exactly one.
1231 auto yieldedType = cast<VectorType>(operand->get().getType());
1232 int64_t destDistributedDim =
1233 getDistributedDim(yieldedType, distributedType);
1234 assert(destDistributedDim != -1 && "could not find distributed dimension");
1235
1236 VectorType srcType = insertOp.getSourceVectorType();
1237 VectorType destType = insertOp.getDestVectorType();
1238 // Currently we require that both source (kD) and dest (nD) vectors are
1239 // distributed. This requires that distributedDim (d) is contained in the
1240 // last k dims of the dest vector (d >= n - k).
1241 // TODO: Add support for case where source vector is not distributed.
1242 int64_t sourceDistributedDim =
1243 destDistributedDim - (destType.getRank() - srcType.getRank());
1244 if (sourceDistributedDim < 0)
1245 return rewriter.notifyMatchFailure(
1246 insertOp,
1247 "distributed dimension must be in the last k dims of dest vector");
1248 // Distributed dimension must be fully inserted.
1249 if (srcType.getDimSize(sourceDistributedDim) !=
1250 destType.getDimSize(destDistributedDim))
1251 return rewriter.notifyMatchFailure(
1252 insertOp, "distributed dimension must be fully inserted");
1253 SmallVector<int64_t> newSourceDistShape(
1254 insertOp.getSourceVectorType().getShape());
1255 newSourceDistShape[sourceDistributedDim] =
1256 distributedType.getDimSize(destDistributedDim);
1257 auto newSourceTy =
1258 VectorType::get(newSourceDistShape, distributedType.getElementType());
1259 VectorType newDestTy = distributedType;
1260 SmallVector<size_t> newRetIndices;
1261 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1262 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1263 {newSourceTy, newDestTy}, newRetIndices);
1264 rewriter.setInsertionPointAfter(newWarpOp);
1265 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1266 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1267 // Create a new insert strided slice op that inserts distributed source into
1268 // distributed dest.
1269 Value newInsert = vector::InsertStridedSliceOp::create(
1270 rewriter, insertOp.getLoc(), distributedDest.getType(),
1271 distributedSource, distributedDest, insertOp.getOffsets(),
1272 insertOp.getStrides());
1273 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
1274 return success();
1275 }
1276};
1277
1278/// Sink out extract_strided_slice op feeding into a warp op yield.
1279/// ```
1280/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1281/// ...
1282/// %src = ... : vector<64x32xf32>
1283/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1284/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1285/// gpu.yield %extract : vector<16x32xf32>
1286/// }
1287/// ```
1288/// To
1289/// ```
1290/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1291/// ...
1292/// %src = ... : vector<64x32xf32>
1293/// gpu.yield %src : vector<64x32xf32>
1294/// }
1295/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1296/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1297/// ```
1298/// NOTE: Current support assumes that the extraction happens only on non
1299/// distributed dimensions (does not require cross lane communication).
1300struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
1301 using Base::Base;
1302 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1303 PatternRewriter &rewriter) const override {
1304 OpOperand *operand =
1305 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1306 if (!operand)
1307 return failure();
1308 unsigned int operandNumber = operand->getOperandNumber();
1309 auto extractOp =
1310 operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
1311 auto distributedType =
1312 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1313 // Distributed type must be 2D or higher.
1314 // TODO: Support 1D distributed types.
1315 if (distributedType.getRank() < 2)
1316 return rewriter.notifyMatchFailure(
1317 extractOp, "result vector type must be 2D or higher");
1318
1319 // Find the distributed dimension. There should be exactly one.
1320 auto yieldedType = cast<VectorType>(operand->get().getType());
1321 int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1322 assert(distributedDim != -1 && "could not find distributed dimension");
1323
1324 int64_t numOfExtractedDims =
1325 static_cast<int64_t>(extractOp.getSizes().size());
1326 // If the distributed dim is included in the extracted dims, then we make
1327 // sure distributed dim is fully extracted. If distributed dim is not
1328 // included in extracted dims, it is guaranteed to be fully extracted (i.e.
1329 // distributed dim comes after all the extracted dims)
1330 // TODO: Partial extraction from distributed dimension require cross lane
1331 // communication.
1332 if (distributedDim < numOfExtractedDims) {
1333 int64_t distributedDimOffset =
1334 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1335 .getInt();
1336 int64_t distributedDimSize =
1337 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1338 .getInt();
1339 if (distributedDimOffset != 0 ||
1340 distributedDimSize != yieldedType.getDimSize(distributedDim))
1341 return rewriter.notifyMatchFailure(
1342 extractOp, "distributed dimension must be fully extracted");
1343 }
1344 SmallVector<int64_t> newDistributedShape(
1345 extractOp.getSourceVectorType().getShape());
1346 newDistributedShape[distributedDim] =
1347 distributedType.getDimSize(distributedDim);
1348 auto newDistributedType =
1349 VectorType::get(newDistributedShape, distributedType.getElementType());
1350 SmallVector<size_t> newRetIndices;
1351 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1352 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1353 newRetIndices);
1354 rewriter.setInsertionPointAfter(newWarpOp);
1355 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1356 extractOp.getSizes(), [](Attribute attr) { return attr; });
1357 // Update the distributed sizes to match the distributed type.
1358 if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
1359 distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1360 distributedType.getDimSize(distributedDim));
1361
1362 // Create a new extract strided slice op that extracts from the
1363 // distributed vector.
1364 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1365 Value newExtract = vector::ExtractStridedSliceOp::create(
1366 rewriter, extractOp.getLoc(), distributedType, distributedVec,
1367 extractOp.getOffsets(),
1368 ArrayAttr::get(rewriter.getContext(), distributedSizes),
1369 extractOp.getStrides());
1370 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1371 newExtract);
1372 return success();
1373 }
1374};
1375
1376/// Pattern to move out vector.extract of a 2-D or higher-rank source vector.
1377/// The extract is moved outside the warp op and applied to the lane-local source.
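/// For illustration (the shapes and warp size below are arbitrary), with the
/// trailing dimension distributed over 32 lanes:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
///   %v = "some_def"() : () -> (vector<4x96xf32>)
///   %e = vector.extract %v[1] : vector<96xf32> from vector<4x96xf32>
///   gpu.yield %e : vector<96xf32>
/// }
/// ```
/// is rewritten so that the warp op yields the lane-local slice of the source
/// and the extract is applied outside of it:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x3xf32>) {
///   %v = "some_def"() : () -> (vector<4x96xf32>)
///   gpu.yield %v : vector<4x96xf32>
/// }
/// %r = vector.extract %w[1] : vector<3xf32> from vector<4x3xf32>
/// ```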
1378struct WarpOpExtract : public WarpDistributionPattern {
1379 using Base::Base;
1380 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1381 PatternRewriter &rewriter) const override {
1382 OpOperand *operand =
1383 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1384 if (!operand)
1385 return failure();
1386 unsigned int operandNumber = operand->getOperandNumber();
1387 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1388 VectorType extractSrcType = extractOp.getSourceVectorType();
1389 Location loc = extractOp.getLoc();
1390
1391    // For 1-d or 0-d source cases, we rely on the WarpOpExtractScalar pattern.
1392 if (extractSrcType.getRank() <= 1) {
1393 return failure();
1394 }
1395
1396    // All following cases handle source vectors of rank 2 or higher.
1397
1398 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1399 // There is no distribution, this is a broadcast. Simply move the extract
1400 // out of the warp op.
1401 // TODO: This could be optimized. E.g., in case of a scalar result, let
1402 // one lane extract and shuffle the result to all other lanes (same as
1403 // the 1d case).
1404 SmallVector<size_t> newRetIndices;
1405 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1406 rewriter, warpOp, {extractOp.getSource()},
1407 {extractOp.getSourceVectorType()}, newRetIndices);
1408 rewriter.setInsertionPointAfter(newWarpOp);
1409 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1410 // Extract from distributed vector.
1411 Value newExtract = vector::ExtractOp::create(
1412 rewriter, loc, distributedVec, extractOp.getMixedPosition());
1413 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1414 newExtract);
1415 return success();
1416 }
1417
1418 // Find the distributed dimension. There should be exactly one.
1419 auto distributedType =
1420 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1421 auto yieldedType = cast<VectorType>(operand->get().getType());
1422 int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1423 assert(distributedDim != -1 && "could not find distributed dimension");
1424 (void)distributedDim;
1425
1426 // Yield source vector from warp op.
1427 SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1428 for (int i = 0; i < distributedType.getRank(); ++i)
1429 newDistributedShape[i + extractOp.getNumIndices()] =
1430 distributedType.getDimSize(i);
1431 auto newDistributedType =
1432 VectorType::get(newDistributedShape, distributedType.getElementType());
1433 SmallVector<size_t> newRetIndices;
1434 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1435 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1436 newRetIndices);
1437 rewriter.setInsertionPointAfter(newWarpOp);
1438 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1439 // Extract from distributed vector.
1440 Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1441 extractOp.getMixedPosition());
1442 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1443 newExtract);
1444 return success();
1445 }
1446};
1447
1448/// Pattern to move out vector.extract with a scalar result.
1449/// Only supports 1-D and 0-D sources for now.
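/// As an illustrative sketch (shapes, the extraction index, and the warp size
/// are arbitrary), with 64 elements distributed over 32 lanes:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %v = "some_def"() : () -> (vector<64xf32>)
///   %e = vector.extract %v[8] : f32 from vector<64xf32>
///   gpu.yield %e : f32
/// }
/// ```
/// becomes a warp op yielding the distributed vector<2xf32>, followed by a
/// lane-local vector.extract at position 8 % 2 = 0 and a shuffle, provided by
/// `warpShuffleFromIdxFn`, that broadcasts the value held by the owning lane
/// (lane 8 / 2 = 4) to all lanes.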
1450struct WarpOpExtractScalar : public WarpDistributionPattern {
1451 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1452 PatternBenefit b = 1)
1453 : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1454 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1455 PatternRewriter &rewriter) const override {
1456 OpOperand *operand =
1457 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1458 if (!operand)
1459 return failure();
1460 unsigned int operandNumber = operand->getOperandNumber();
1461 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1462 VectorType extractSrcType = extractOp.getSourceVectorType();
1463 // Only supports 1-D or 0-D sources for now.
1464 if (extractSrcType.getRank() > 1) {
1465 return rewriter.notifyMatchFailure(
1466 extractOp, "only 0-D or 1-D source supported for now");
1467 }
1468 // TODO: Supported shuffle types should be parameterizable, similar to
1469 // `WarpShuffleFromIdxFn`.
1470 if (!extractSrcType.getElementType().isF32() &&
1471 !extractSrcType.getElementType().isInteger(32))
1472 return rewriter.notifyMatchFailure(
1473 extractOp, "only f32/i32 element types are supported");
1474 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1475 Type elType = extractSrcType.getElementType();
1476 VectorType distributedVecType;
1477 if (!is0dOrVec1Extract) {
1478 assert(extractSrcType.getRank() == 1 &&
1479 "expected that extract src rank is 0 or 1");
1480 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1481 return failure();
1482 int64_t elementsPerLane =
1483 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1484 distributedVecType = VectorType::get({elementsPerLane}, elType);
1485 } else {
1486 distributedVecType = extractSrcType;
1487 }
1488 // Yield source vector and position (if present) from warp op.
1489 SmallVector<Value> additionalResults{extractOp.getSource()};
1490 SmallVector<Type> additionalResultTypes{distributedVecType};
1491 additionalResults.append(
1492 SmallVector<Value>(extractOp.getDynamicPosition()));
1493 additionalResultTypes.append(
1494 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1495
1496 Location loc = extractOp.getLoc();
1497 SmallVector<size_t> newRetIndices;
1498 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1499 rewriter, warpOp, additionalResults, additionalResultTypes,
1500 newRetIndices);
1501 rewriter.setInsertionPointAfter(newWarpOp);
1502 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1503
1504 // 0d extract: The new warp op broadcasts the source vector to all lanes.
1505 // All lanes extract the scalar.
1506 if (is0dOrVec1Extract) {
1507 Value newExtract;
1508 SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1509 newExtract =
1510 vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
1511 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1512 newExtract);
1513 return success();
1514 }
1515
1516 int64_t staticPos = extractOp.getStaticPosition()[0];
1517 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1518 ? (newWarpOp->getResult(newRetIndices[1]))
1519 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1520 // 1d extract: Distribute the source vector. One lane extracts and shuffles
1521 // the value to all other lanes.
1522 int64_t elementsPerLane = distributedVecType.getShape()[0];
1523 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1524 // tid of extracting thread: pos / elementsPerLane
1525 Value broadcastFromTid = affine::makeComposedAffineApply(
1526 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1527 // Extract at position: pos % elementsPerLane
1528 Value newPos =
1529 elementsPerLane == 1
1530 ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
1531 : affine::makeComposedAffineApply(rewriter, loc,
1532 sym0 % elementsPerLane, pos);
1533 Value extracted =
1534 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1535
1536 // Shuffle the extracted value to all lanes.
1537 Value shuffled = warpShuffleFromIdxFn(
1538 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1539 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1540 return success();
1541 }
1542
1543private:
1544 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1545};
1546
1547/// Pattern to move out vector.insert with a scalar input.
1548/// Only supports 1-D and 0-D destinations for now.
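/// As an illustrative sketch (shapes, the insertion index, and the warp size
/// are arbitrary), with 64 elements distributed over 32 lanes:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   %v = "some_def"() : () -> (vector<64xf32>)
///   %f = "some_scalar"() : () -> (f32)
///   %i = vector.insert %f, %v [8] : f32 into vector<64xf32>
///   gpu.yield %i : vector<64xf32>
/// }
/// ```
/// becomes a warp op yielding the distributed destination (vector<2xf32>) and
/// the scalar, followed by an scf.if in which only the owning lane (lane
/// 8 / 2 = 4) inserts the scalar at position 8 % 2 = 0 of its slice; all other
/// lanes forward their slice unchanged.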
1549struct WarpOpInsertScalar : public WarpDistributionPattern {
1550 using Base::Base;
1551 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1552 PatternRewriter &rewriter) const override {
1553 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1554 if (!operand)
1555 return failure();
1556 unsigned int operandNumber = operand->getOperandNumber();
1557 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1558 VectorType vecType = insertOp.getDestVectorType();
1559 VectorType distrType =
1560 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1561
1562 // Only supports 1-D or 0-D destinations for now.
1563 if (vecType.getRank() > 1) {
1564 return rewriter.notifyMatchFailure(
1565 insertOp, "only 0-D or 1-D source supported for now");
1566 }
1567
1568 // Yield destination vector, source scalar and position from warp op.
1569 SmallVector<Value> additionalResults{insertOp.getDest(),
1570 insertOp.getValueToStore()};
1571 SmallVector<Type> additionalResultTypes{
1572 distrType, insertOp.getValueToStore().getType()};
1573 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1574 additionalResultTypes.append(
1575 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1576
1577 Location loc = insertOp.getLoc();
1578 SmallVector<size_t> newRetIndices;
1579 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1580 rewriter, warpOp, additionalResults, additionalResultTypes,
1581 newRetIndices);
1582 rewriter.setInsertionPointAfter(newWarpOp);
1583 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1584 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1585 rewriter.setInsertionPointAfter(newWarpOp);
1586
1587 OpFoldResult pos;
1588 if (vecType.getRank() != 0) {
1589 int64_t staticPos = insertOp.getStaticPosition()[0];
1590 pos = ShapedType::isDynamic(staticPos)
1591 ? (newWarpOp->getResult(newRetIndices[2]))
1592 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1593 }
1594
1595 // This condition is always true for 0-d vectors.
1596 if (vecType == distrType) {
1597 Value newInsert;
1598 SmallVector<OpFoldResult> indices;
1599 if (pos) {
1600 indices.push_back(pos);
1601 }
1602 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1603 distributedVec, indices);
1604 // Broadcast: Simply move the vector.insert op out.
1605 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1606 newInsert);
1607 return success();
1608 }
1609
1610 // This is a distribution. Only one lane should insert.
1611 int64_t elementsPerLane = distrType.getShape()[0];
1612 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1613    // tid of inserting lane: pos / elementsPerLane
1614 Value insertingLane = affine::makeComposedAffineApply(
1615 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1616 // Insert position: pos % elementsPerLane
1617 OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1618 rewriter, loc, sym0 % elementsPerLane, pos);
1619 Value isInsertingLane =
1620 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1621 newWarpOp.getLaneid(), insertingLane);
1622 Value newResult =
1623 scf::IfOp::create(
1624 rewriter, loc, isInsertingLane,
1625 /*thenBuilder=*/
1626 [&](OpBuilder &builder, Location loc) {
1627 Value newInsert = vector::InsertOp::create(
1628 builder, loc, newSource, distributedVec, newPos);
1629 scf::YieldOp::create(builder, loc, newInsert);
1630 },
1631 /*elseBuilder=*/
1632 [&](OpBuilder &builder, Location loc) {
1633 scf::YieldOp::create(builder, loc, distributedVec);
1634 })
1635 .getResult(0);
1636 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1637 return success();
1638 }
1639};
1640
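/// Pattern to move out vector.insert with a 2-D or higher-rank destination
/// vector. As an illustrative sketch (shapes and warp size are arbitrary),
/// with the trailing dimension distributed over 32 lanes:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<128x3xf32>) {
///   %s = "some_def"() : () -> (vector<96xf32>)
///   %d = "some_def"() : () -> (vector<128x96xf32>)
///   %i = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
///   gpu.yield %i : vector<128x96xf32>
/// }
/// ```
/// becomes a warp op yielding the distributed source (vector<3xf32>) and
/// destination (vector<128x3xf32>), followed by the same vector.insert outside
/// of it. When the distributed dimension is instead covered by the insertion
/// indices, only a single lane performs the insert (see the case analysis in
/// the code below).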
1641struct WarpOpInsert : public WarpDistributionPattern {
1642 using Base::Base;
1643 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1644 PatternRewriter &rewriter) const override {
1645 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1646 if (!operand)
1647 return failure();
1648 unsigned int operandNumber = operand->getOperandNumber();
1649 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1650 Location loc = insertOp.getLoc();
1651
1652    // For 1-d or 0-d destinations, we rely on the WarpOpInsertScalar pattern.
1653 if (insertOp.getDestVectorType().getRank() <= 1) {
1654 return failure();
1655 }
1656
1657    // All following cases handle destination vectors of rank 2 or higher.
1658
1659 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1660 // There is no distribution, this is a broadcast. Simply move the insert
1661 // out of the warp op.
1662 SmallVector<size_t> newRetIndices;
1663 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1664 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1665 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1666 newRetIndices);
1667 rewriter.setInsertionPointAfter(newWarpOp);
1668 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1669 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1670 Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1671 distributedDest,
1672 insertOp.getMixedPosition());
1673 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1674 newResult);
1675 return success();
1676 }
1677
1678 // Find the distributed dimension. There should be exactly one.
1679 auto distrDestType =
1680 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1681 auto yieldedType = cast<VectorType>(operand->get().getType());
1682 int64_t distrDestDim = -1;
1683 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1684 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1685 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1686 // support distributing multiple dimensions in the future.
1687 assert(distrDestDim == -1 && "found multiple distributed dims");
1688 distrDestDim = i;
1689 }
1690 }
1691 assert(distrDestDim != -1 && "could not find distributed dimension");
1692
1693 // Compute the distributed source vector type.
1694 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1695 SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1696 // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1697 // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1698 // insert a smaller vector<3xf32>.
1699 // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1700 // case, one lane will insert the source vector<96xf32>. The other
1701 // lanes will not do anything.
1702 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1703 if (distrSrcDim >= 0)
1704 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1705 auto distrSrcType =
1706 VectorType::get(distrSrcShape, distrDestType.getElementType());
1707
1708 // Yield source and dest vectors from warp op.
1709 SmallVector<size_t> newRetIndices;
1710 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1711 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1712 {distrSrcType, distrDestType}, newRetIndices);
1713 rewriter.setInsertionPointAfter(newWarpOp);
1714 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1715 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1716
1717 // Insert into the distributed vector.
1718 Value newResult;
1719 if (distrSrcDim >= 0) {
1720 // Every lane inserts a small piece.
1721 newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1722 distributedDest,
1723 insertOp.getMixedPosition());
1724 } else {
1725 // One lane inserts the entire source vector.
1726 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1727 SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1728 SmallVector<int64_t> newPos = getAsIntegers(pos);
1729 // tid of inserting lane: pos / elementsPerLane
1730 Value insertingLane = arith::ConstantIndexOp::create(
1731 rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1732 Value isInsertingLane =
1733 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1734 newWarpOp.getLaneid(), insertingLane);
1735 // Insert position: pos % elementsPerLane
1736 newPos[distrDestDim] %= elementsPerLane;
1737 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1738 Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1739 distributedDest, newPos);
1740 scf::YieldOp::create(builder, loc, newInsert);
1741 };
1742 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1743 scf::YieldOp::create(builder, loc, distributedDest);
1744 };
1745 newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1746 /*thenBuilder=*/insertingBuilder,
1747 /*elseBuilder=*/nonInsertingBuilder)
1748 .getResult(0);
1749 }
1750
1751 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1752 return success();
1753 }
1754};
1755
1756/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1757/// the scf.if is the last operation in the region so that it doesn't
1758/// change the order of execution. This creates a new scf.if after the
1759/// WarpExecuteOnLane0Op. The body of each branch of the new scf.if is wrapped
1760/// in an "inner" WarpExecuteOnLane0Op. Example:
1761/// ```
1762/// gpu.warp_execute_on_lane_0(%laneid)[32] {
1763/// %payload = ... : vector<32xindex>
1764/// scf.if %pred {
1765/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1766/// vector<32xindex>
1767/// }
1768/// gpu.yield
1769/// }
1770/// ```
/// To
/// ```
1771/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xindex>) {
1772/// %payload = ... : vector<32xindex>
1773/// gpu.yield %payload : vector<32xindex>
1774/// }
1775/// scf.if %pred {
1776/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1777/// ^bb0(%arg1: vector<32xindex>):
1778/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1779/// }
1780/// }
1781/// ```
1782struct WarpOpScfIfOp : public WarpDistributionPattern {
1783 WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1784 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1785 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1786 PatternRewriter &rewriter) const override {
1787 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1788 // Only pick up `IfOp` if it is the last op in the region.
1789 Operation *lastNode = warpOpYield->getPrevNode();
1790 auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1791 if (!ifOp)
1792 return failure();
1793
1794 // The current `WarpOp` can yield two types of values:
1795 // 1. Not results of `IfOp`:
1796 // Preserve them in the new `WarpOp`.
1797 // Collect their yield index to remap the usages.
1798 // 2. Results of `IfOp`:
1799 // They are not part of the new `WarpOp` results.
1800 // Map current warp's yield operand index to `IfOp` result idx.
1801 SmallVector<Value> nonIfYieldValues;
1802 SmallVector<unsigned> nonIfYieldIndices;
1803 llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1804 llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1805 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1806 const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
1807 if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
1808 nonIfYieldValues.push_back(yieldOperand.get());
1809 nonIfYieldIndices.push_back(yieldOperandIdx);
1810 continue;
1811 }
1812 OpResult ifResult = cast<OpResult>(yieldOperand.get());
1813 const unsigned ifResultIdx = ifResult.getResultNumber();
1814 ifResultMapping[yieldOperandIdx] = ifResultIdx;
1815      // If this `ifOp` result has vector type and is yielded by the
1816      // `WarpOp`, keep track of the distributed type for this result.
1817 if (!isa<VectorType>(ifResult.getType()))
1818 continue;
1819 VectorType distType =
1820 cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1821 ifResultDistTypes[ifResultIdx] = distType;
1822 }
1823
1824    // Collect `WarpOp`-defined values used inside `ifOp`; the new warp op must
1825    // return them.
1826 auto [escapingValuesThen, escapingValueInputTypesThen,
1827 escapingValueDistTypesThen] =
1828 getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1829 distributionMapFn);
1830 auto [escapingValuesElse, escapingValueInputTypesElse,
1831 escapingValueDistTypesElse] =
1832 getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1833 distributionMapFn);
1834 if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1835 llvm::is_contained(escapingValueDistTypesElse, Type{}))
1836 return failure();
1837
1838    // The new `WarpOp` yields values grouped in the following order:
1839    // 1. The branch condition.
1840    // 2. Values escaping into the then branch.
1841    // 3. Values escaping into the else branch.
1842    // 4. All non-`ifOp` yielded values.
1843 SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1844 newWarpOpYieldValues.append(escapingValuesThen.begin(),
1845 escapingValuesThen.end());
1846 newWarpOpYieldValues.append(escapingValuesElse.begin(),
1847 escapingValuesElse.end());
1848 SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1849 newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1850 escapingValueDistTypesThen.end());
1851 newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1852 escapingValueDistTypesElse.end());
1853
1854 for (auto [idx, val] :
1855 llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1856 newWarpOpYieldValues.push_back(val);
1857 newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1858 }
1859 // Replace the old `WarpOp` with the new one that has additional yield
1860 // values and types.
1861 SmallVector<size_t> newIndices;
1862 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1863 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
1864 // `ifOp` returns the result of the inner warp op.
1865 SmallVector<Type> newIfOpDistResTypes;
1866 for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1867 Type distType = cast<Value>(res).getType();
1868 if (auto vecType = dyn_cast<VectorType>(distType)) {
1869 AffineMap map = distributionMapFn(cast<Value>(res));
1870        // Fall back to computing the type from the map if it was not recorded.
1871 distType = ifResultDistTypes.count(i)
1872 ? ifResultDistTypes[i]
1873 : getDistributedType(vecType, map, warpOp.getWarpSize());
1874 }
1875 newIfOpDistResTypes.push_back(distType);
1876 }
1877 // Create a new `IfOp` outside the new `WarpOp` region.
1878 OpBuilder::InsertionGuard g(rewriter);
1879 rewriter.setInsertionPointAfter(newWarpOp);
1880 auto newIfOp = scf::IfOp::create(
1881 rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1882 newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
1883 static_cast<bool>(ifOp.elseBlock()));
1884 auto encloseRegionInWarpOp =
1885 [&](Block *oldIfBranch, Block *newIfBranch,
1886 llvm::SmallSetVector<Value, 32> &escapingValues,
1887 SmallVector<Type> &escapingValueInputTypes,
1888 size_t warpResRangeStart) {
1889 OpBuilder::InsertionGuard g(rewriter);
1890 if (!newIfBranch)
1891 return;
1892 rewriter.setInsertionPointToStart(newIfBranch);
1893 llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1894 SmallVector<Value> innerWarpInputVals;
1895 SmallVector<Type> innerWarpInputTypes;
1896 for (size_t i = 0; i < escapingValues.size();
1897 ++i, ++warpResRangeStart) {
1898 innerWarpInputVals.push_back(
1899 newWarpOp.getResult(newIndices[warpResRangeStart]));
1900 escapeValToBlockArgIndex[escapingValues[i]] =
1901 innerWarpInputTypes.size();
1902 innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1903 }
1904 auto innerWarp = WarpExecuteOnLane0Op::create(
1905 rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1906 newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1907 innerWarpInputVals, innerWarpInputTypes);
1908
1909 innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1910 innerWarp.getWarpRegion().addArguments(
1911 innerWarpInputTypes,
1912 SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1913
1914 SmallVector<Value> yieldOperands;
1915 for (Value operand : oldIfBranch->getTerminator()->getOperands())
1916 yieldOperands.push_back(operand);
1917 rewriter.eraseOp(oldIfBranch->getTerminator());
1918
1919 rewriter.setInsertionPointToEnd(innerWarp.getBody());
1920 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1921 rewriter.setInsertionPointAfter(innerWarp);
1922 scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1923
1924 // Update any users of escaping values that were forwarded to the
1925 // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1926 innerWarp.walk([&](Operation *op) {
1927 for (OpOperand &operand : op->getOpOperands()) {
1928 auto it = escapeValToBlockArgIndex.find(operand.get());
1929 if (it == escapeValToBlockArgIndex.end())
1930 continue;
1931 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1932 }
1933 });
1934 mlir::vector::moveScalarUniformCode(innerWarp);
1935 };
1936 encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1937 &newIfOp.getThenRegion().front(), escapingValuesThen,
1938 escapingValueInputTypesThen, 1);
1939 if (!ifOp.getElseRegion().empty())
1940 encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1941 &newIfOp.getElseRegion().front(),
1942 escapingValuesElse, escapingValueInputTypesElse,
1943 1 + escapingValuesThen.size());
1944    // Update users of values yielded by the `IfOp` through the `WarpOp` to use
1945    // the corresponding new `IfOp` results.
1946 for (auto [origIdx, newIdx] : ifResultMapping)
1947 rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
1948 newIfOp.getResult(newIdx), newIfOp);
1949 return success();
1950 }
1951
1952private:
1953 DistributionMapFn distributionMapFn;
1954};
1955
1956/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1957/// the scf.ForOp is the last operation in the region so that it doesn't
1958/// change the order of execution. This creates a new scf.for region after the
1959/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1960/// WarpExecuteOnLane0Op region. Example:
1961/// ```
1962/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1963/// ...
1964/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1965/// -> (vector<128xf32>) {
1966/// ...
1967/// scf.yield %r : vector<128xf32>
1968/// }
1969/// gpu.yield %v1 : vector<128xf32>
1970/// }
1971/// ```
1972/// To:
/// ```
1973/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1974/// ...
1975/// gpu.yield %v : vector<128xf32>
1976/// }
1977/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1978/// -> (vector<4xf32>) {
1979/// %iw = gpu.warp_execute_on_lane_0(%laneid)
1980/// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1981/// ^bb0(%arg: vector<128xf32>):
1982/// ...
1983/// gpu.yield %ir : vector<128xf32>
1984/// }
1985/// scf.yield %iw : vector<4xf32>
1986/// }
1987/// ```
1988struct WarpOpScfForOp : public WarpDistributionPattern {
1989
1990 WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1991 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1992 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1993 PatternRewriter &rewriter) const override {
1994 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1995 // Only pick up `ForOp` if it is the last op in the region.
1996 Operation *lastNode = warpOpYield->getPrevNode();
1997 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1998 if (!forOp)
1999 return failure();
2000 // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
2001 // Those Values need to be returned by the new warp op.
2002 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2003 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2004 distributionMapFn);
2005 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2006 return failure();
2007 // `WarpOp` can yield two types of values:
2008 // 1. Values that are not results of the `ForOp`:
2009 // These values must also be yielded by the new `WarpOp`. Also, we need
2010 // to record the index mapping for these values to replace them later.
2011 // 2. Values that are results of the `ForOp`:
2012 // In this case, we record the index mapping between the `WarpOp` result
2013 // index and matching `ForOp` result index.
2014 // Additionally, we keep track of the distributed types for all `ForOp`
2015 // vector results.
2016 SmallVector<Value> nonForYieldedValues;
2017 SmallVector<unsigned> nonForResultIndices;
2018 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2019 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2020 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2021 // Yielded value is not a result of the forOp.
2022 if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
2023 nonForYieldedValues.push_back(yieldOperand.get());
2024 nonForResultIndices.push_back(yieldOperand.getOperandNumber());
2025 continue;
2026 }
2027 OpResult forResult = cast<OpResult>(yieldOperand.get());
2028 unsigned int forResultNumber = forResult.getResultNumber();
2029 forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
2030      // If this `ForOp` result has vector type and is yielded by the
2031      // `WarpOp`, keep track of the distributed type for this result.
2032 if (!isa<VectorType>(forResult.getType()))
2033 continue;
2034 VectorType distType = cast<VectorType>(
2035 warpOp.getResult(yieldOperand.getOperandNumber()).getType());
2036 forResultDistTypes[forResultNumber] = distType;
2037 }
2038
2039    // The newly created `WarpOp` will yield values in the following order:
2040 // 1. Loop bounds.
2041 // 2. All init args of the `ForOp`.
2042 // 3. All escaping values.
2043 // 4. All non-`ForOp` yielded values.
2044 SmallVector<Value> newWarpOpYieldValues;
2045 SmallVector<Type> newWarpOpDistTypes;
2046 newWarpOpYieldValues.insert(
2047 newWarpOpYieldValues.end(),
2048 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2049 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2050 {forOp.getLowerBound().getType(),
2051 forOp.getUpperBound().getType(),
2052 forOp.getStep().getType()});
2053 for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2054 newWarpOpYieldValues.push_back(initArg);
2055 // Compute the distributed type for this init arg.
2056 Type distType = initArg.getType();
2057 if (auto vecType = dyn_cast<VectorType>(distType)) {
2058      // If the `ForOp` result corresponding to this init arg is already
2059      // yielded, we can get the distributed type from the `forResultDistTypes`
2060      // map. Otherwise, we compute it using `distributionMapFn`.
2061 AffineMap map = distributionMapFn(initArg);
2062 distType = forResultDistTypes.count(i)
2063 ? forResultDistTypes[i]
2064 : getDistributedType(vecType, map, warpOp.getWarpSize());
2065 }
2066 newWarpOpDistTypes.push_back(distType);
2067 }
2068 // Insert escaping values and their distributed types.
2069 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2070 escapingValues.begin(), escapingValues.end());
2071 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2072 escapingValueDistTypes.begin(),
2073 escapingValueDistTypes.end());
2074 // Next, we insert all non-`ForOp` yielded values and their distributed
2075 // types.
2076 for (auto [i, v] :
2077 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2078 newWarpOpYieldValues.push_back(v);
2079 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2080 }
2081 // Create the new `WarpOp` with the updated yield values and types.
2082 SmallVector<size_t> newIndices;
2083 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2084 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2085
2086 // Next, we create a new `ForOp` with the init args yielded by the new
2087 // `WarpOp`.
2088 const unsigned initArgsStartIdx = 3; // After loop bounds.
2089 const unsigned escapingValuesStartIdx =
2090 initArgsStartIdx +
2091 forOp.getInitArgs().size(); // `ForOp` init args are positioned before
2092 // escaping values in the new `WarpOp`.
2093 SmallVector<Value> newForOpOperands;
2094 for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2095 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2096
2097 // Create a new `ForOp` outside the new `WarpOp` region.
2098 OpBuilder::InsertionGuard g(rewriter);
2099 rewriter.setInsertionPointAfter(newWarpOp);
2100 auto newForOp = scf::ForOp::create(
2101 rewriter, forOp.getLoc(),
2102        /*lowerBound=*/newWarpOp.getResult(newIndices[0]),
2103        /*upperBound=*/newWarpOp.getResult(newIndices[1]),
2104        /*step=*/newWarpOp.getResult(newIndices[2]), newForOpOperands,
2105 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
2106 // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
2107 // newly created `ForOp`. This `WarpOp` will contain all ops that were
2108 // contained within the original `ForOp` body.
2109 rewriter.setInsertionPointToStart(newForOp.getBody());
2110
2111 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2112 newForOp.getRegionIterArgs().end());
2113 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2114 forOp.getResultTypes().end());
2115 // Escaping values are forwarded to the inner `WarpOp` as its (additional)
2116 // arguments. We keep track of the mapping between these values and their
2117 // argument index in the inner `WarpOp` (to replace users later).
2118 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2119 for (size_t i = escapingValuesStartIdx;
2120 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2121 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2122 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2123 innerWarpInputType.size();
2124 innerWarpInputType.push_back(
2125 escapingValueInputTypes[i - escapingValuesStartIdx]);
2126 }
2127 // Create the inner `WarpOp` with the new input values and types.
2128 auto innerWarp = WarpExecuteOnLane0Op::create(
2129 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2130 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2131 innerWarpInputType);
2132
2133 // Inline the `ForOp` body into the inner `WarpOp` body.
2134 SmallVector<Value> argMapping;
2135 argMapping.push_back(newForOp.getInductionVar());
2136 for (Value args : innerWarp.getBody()->getArguments())
2137 argMapping.push_back(args);
2138
2139 argMapping.resize(forOp.getBody()->getNumArguments());
2140 SmallVector<Value> yieldOperands;
2141 for (Value operand : forOp.getBody()->getTerminator()->getOperands())
2142 yieldOperands.push_back(operand);
2143
2144 rewriter.eraseOp(forOp.getBody()->getTerminator());
2145 rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2146
2147 // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
2148 // original `ForOp` results.
2149 rewriter.setInsertionPointToEnd(innerWarp.getBody());
2150 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2151 rewriter.setInsertionPointAfter(innerWarp);
2152 // Insert a scf.yield op at the end of the new `ForOp` body that yields
2153 // the inner `WarpOp` results.
2154 if (!innerWarp.getResults().empty())
2155 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2156
2157 // Update the users of the new `WarpOp` results that were coming from the
2158 // original `ForOp` to the corresponding new `ForOp` result.
2159 for (auto [origIdx, newIdx] : forResultMapping)
2160 rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
2161 newForOp.getResult(newIdx), newForOp);
2162 // Update any users of escaping values that were forwarded to the
2163 // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
2164 newForOp.walk([&](Operation *op) {
2165 for (OpOperand &operand : op->getOpOperands()) {
2166 auto it = argIndexMapping.find(operand.get());
2167 if (it == argIndexMapping.end())
2168 continue;
2169 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
2170 }
2171 });
2172
2173 // Finally, hoist out any now uniform code from the inner `WarpOp`.
2174 mlir::vector::moveScalarUniformCode(innerWarp);
2175 return success();
2176 }
2177
2178private:
2179 DistributionMapFn distributionMapFn;
2180};
2181
2182/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
2183/// The vector is reduced in parallel. Currently limited to vector sizes that
2184/// are a multiple of the warp size. E.g.:
2185/// ```
2186/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
2187/// %0 = "some_def"() : () -> (vector<32xf32>)
2188/// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
2189/// gpu.yield %1 : f32
2190/// }
2191/// ```
2192/// is lowered to:
2193/// ```
2194/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
2195/// %1 = "some_def"() : () -> (vector<32xf32>)
2196/// gpu.yield %1 : vector<32xf32>
2197/// }
2198/// %a = vector.extract %0[0] : f32 from vector<1xf32>
2199/// %r = ("warp.reduction %a")
2200/// ```
2201struct WarpOpReduction : public WarpDistributionPattern {
2202 WarpOpReduction(MLIRContext *context,
2203 DistributedReductionFn distributedReductionFn,
2204 PatternBenefit benefit = 1)
2205 : WarpDistributionPattern(context, benefit),
2206 distributedReductionFn(std::move(distributedReductionFn)) {}
2207
2208 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2209 PatternRewriter &rewriter) const override {
2210 OpOperand *yieldOperand =
2211 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2212 if (!yieldOperand)
2213 return failure();
2214
2215 auto reductionOp =
2216 cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
2217 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2218 // Only rank 1 vectors supported.
2219 if (vectorType.getRank() != 1)
2220 return rewriter.notifyMatchFailure(
2221 warpOp, "Only rank 1 reductions can be distributed.");
2222    // Only vector sizes that are a multiple of the warp size are supported.
2223    if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2224      return rewriter.notifyMatchFailure(
2225          warpOp, "Reduction vector dimension must be a multiple of warp size.");
2226 if (!reductionOp.getType().isIntOrFloat())
2227 return rewriter.notifyMatchFailure(
2228 warpOp, "Reduction distribution currently only supports floats and "
2229 "integer types.");
2230
2231 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2232 // Return vector that will be reduced from the WarpExecuteOnLane0Op.
2233 unsigned operandIndex = yieldOperand->getOperandNumber();
2234 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2235 SmallVector<Type> retTypes = {
2236 VectorType::get({numElements}, reductionOp.getType())};
2237 if (reductionOp.getAcc()) {
2238 yieldValues.push_back(reductionOp.getAcc());
2239 retTypes.push_back(reductionOp.getAcc().getType());
2240 }
2241 SmallVector<size_t> newRetIndices;
2242 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2243 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2244 rewriter.setInsertionPointAfter(newWarpOp);
2245
2246 // Obtain data to reduce for a single lane.
2247 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2248 // Distribute and reduce across threads.
2249 Value fullReduce =
2250 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2251 reductionOp.getKind(), newWarpOp.getWarpSize());
2252 if (reductionOp.getAcc()) {
2253 fullReduce = vector::makeArithReduction(
2254 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2255 newWarpOp.getResult(newRetIndices[1]));
2256 }
2257 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
2258 return success();
2259 }
2260
2261private:
2262 DistributedReductionFn distributedReductionFn;
2263};
2264
2265} // namespace
2266
2272
2273void mlir::vector::populateDistributeTransferWriteOpPatterns(
2274 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2275 unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2276 patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2277 maxNumElementsToExtract, benefit);
2278}
2279
2280void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2281 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2282 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2283 PatternBenefit readBenefit) {
2284 patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2285 patterns
2286 .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2287 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
2288 WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
2289 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2290 patterns.getContext(), benefit);
2291 patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2292 benefit);
2293 patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2294 benefit);
2295 patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2296 benefit);
2297}
2298
2299void mlir::vector::populateDistributeReduction(
2300    RewritePatternSet &patterns,
2301    const DistributedReductionFn &distributedReductionFn,
2302 PatternBenefit benefit) {
2303 patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2304 benefit);
2305}
2306
2307/// Helper to know if an op can be hoisted out of the region.
2308static bool canBeHoisted(Operation *op,
2309 function_ref<bool(Value)> definedOutside) {
2310 return llvm::all_of(op->getOperands(), definedOutside) &&
2311 isMemoryEffectFree(op) && op->getNumRegions() == 0;
2312}
2313
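/// A small sketch (ops and types are arbitrary) of what gets hoisted: in
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %i = arith.addi %a, %b : index
///   %v = "some_def"(%i) : (index) -> (vector<32xf32>)
///   gpu.yield %v : vector<32xf32>
/// }
/// ```
/// the scalar %i is memory-effect free and only uses values defined outside
/// the region, so it is moved in front of the warp op; %v produces a vector
/// result and therefore stays inside.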
2314void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2315 Block *body = warpOp.getBody();
2316
2317 // Keep track of the ops we want to hoist.
2318 llvm::SmallSetVector<Operation *, 8> opsToMove;
2319
2320 // Helper to check if a value is or will be defined outside of the region.
2321 auto isDefinedOutsideOfBody = [&](Value value) {
2322 auto *definingOp = value.getDefiningOp();
2323 return (definingOp && opsToMove.count(definingOp)) ||
2324 warpOp.isDefinedOutsideOfRegion(value);
2325 };
2326
2327 // Do not use walk here, as we do not want to go into nested regions and hoist
2328 // operations from there.
2329 for (auto &op : body->without_terminator()) {
2330 bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2331 return isa<VectorType>(result.getType());
2332 });
2333 if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2334 opsToMove.insert(&op);
2335 }
2336
2337 // Move all the ops marked as uniform outside of the region.
2338 for (Operation *op : opsToMove)
2339 op->moveBefore(warpOp);
2340}
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:372
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
unsigned getNumOperands()
Definition Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead.