VectorDistribute.cpp
1//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/GPU/IR/GPUDialect.h"
12#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
17#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/Interfaces/SideEffectInterfaces.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <utility>
26
27using namespace mlir;
28using namespace mlir::vector;
29using namespace mlir::gpu;
30
31/// Currently the distribution map is implicit based on the vector shape. In the
32/// future it will be part of the op.
33/// Example:
34/// ```
35/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
36/// ...
37/// gpu.yield %3 : vector<32x16x64xf32>
38/// }
39/// ```
40/// Would have an implicit map of:
41/// `(d0, d1, d2) -> (d0, d2)`
42static AffineMap calculateImplicitMap(VectorType sequentialType,
43 VectorType distributedType) {
44 SmallVector<AffineExpr> perm;
45 perm.reserve(1);
46 // Check which dimensions of the sequential type are different than the
47 // dimensions of the distributed type to know the distributed dimensions. Then
48 // associate each distributed dimension to an ID in order.
49 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
51 perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
52 }
53 auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
54 distributedType.getContext());
55 return map;
56}
57
58/// Given a sequential and distributed vector type, returns the distributed
59/// dimension. This function expects that only a single dimension is
60/// distributed.
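/// For illustration (types chosen arbitrarily): given a sequential type
/// vector<16x32xf32> and a distributed type vector<16x1xf32>, only dimension 1
/// differs, so this returns 1.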
61static int getDistributedDim(VectorType sequentialType,
62 VectorType distributedType) {
63 assert(sequentialType.getRank() == distributedType.getRank() &&
64 "sequential and distributed vector types must have the same rank");
65 int64_t distributedDim = -1;
66 for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
68 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
69 // support distributing multiple dimensions in the future.
70 assert(distributedDim == -1 && "found multiple distributed dims");
71 distributedDim = i;
72 }
73 }
74 return distributedDim;
75}
76
77namespace {
78
79/// Helper struct to create the load / store operations that permit transit
80/// through the parallel / sequential and the sequential / parallel boundaries
81/// when lowering a `WarpExecuteOnLane0Op` to an `scf.if` (see
81/// `WarpOpToScfIfPattern`).
82///
83/// The vector distribution dimension is inferred from the vector types.
84struct DistributedLoadStoreHelper {
85 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
86 Value laneId, Value zero)
87 : sequentialVal(sequentialVal), distributedVal(distributedVal),
88 laneId(laneId), zero(zero) {
89 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
90 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
91 if (sequentialVectorType && distributedVectorType)
92 distributionMap =
93 calculateImplicitMap(sequentialVectorType, distributedVectorType);
94 }
95
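 /// Build the offset of this lane's chunk along dimension `index` of the
 /// distributed vector, i.e. `laneId * distributedSize`. As an illustrative
 /// example, with a distributed size of 2 along `index`, lane 5 starts at
 /// offset 10.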
96 Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
97 int64_t distributedSize = distributedVectorType.getDimSize(index);
98 AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
99 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
100 ArrayRef<Value>{laneId});
101 }
102
103 /// Create a store during the process of distributing the
104 /// `gpu.warp_execute_on_lane_0` op.
105 /// Vector distribution assumes the following convention regarding the
106 /// temporary buffers that are created to transition values. This **must**
107 /// be properly specified in the `options.warpAllocationFn`:
108 /// 1. scalars of type T transit through a memref<1xT>.
109 /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
110 Operation *buildStore(RewriterBase &b, Location loc, Value val,
111 Value buffer) {
112 assert((val == distributedVal || val == sequentialVal) &&
113 "Must store either the preregistered distributed or the "
114 "preregistered sequential value.");
115 // Scalar case can directly use memref.store.
116 if (!isa<VectorType>(val.getType()))
117 return memref::StoreOp::create(b, loc, val, buffer, zero);
118
119 // Vector case must use vector::TransferWriteOp which will later lower to
120 // vector.store or memref.store depending on further lowerings.
121 int64_t rank = sequentialVectorType.getRank();
122 SmallVector<Value> indices(rank, zero);
123 if (val == distributedVal) {
124 for (auto dimExpr : distributionMap.getResults()) {
125 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126 indices[index] = buildDistributedOffset(b, loc, index);
127 }
128 }
129 SmallVector<bool> inBounds(indices.size(), true);
130 return vector::TransferWriteOp::create(
131 b, loc, val, buffer, indices,
132 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
133 }
134
135 /// Create a load during the process of distributing the
136 /// `gpu.warp_execute_on_lane_0` op.
137 /// Vector distribution assumes the following convention regarding the
138 /// temporary buffers that are created to transition values. This **must**
139 /// be properly specified in the `options.warpAllocationFn`:
140 /// 1. scalars of type T transit through a memref<1xT>.
141 /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
142 ///
143 /// When broadcastMode is true, the load is not distributed to account for
144 /// the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
145 ///
146 /// Example:
147 ///
148 /// ```
149 /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
150 /// gpu.yield %cst : f32
151 /// }
152 /// // Both types are f32. The constant %cst is broadcasted to all lanes.
153 /// ```
154 /// This behavior is described in more detail in the documentation of the op.
155 Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
156
157 // Scalar case can directly use memref.load.
158 if (!isa<VectorType>(type))
159 return memref::LoadOp::create(b, loc, buffer, zero);
160
161 // Other cases must be vector atm.
162 // Vector case must use vector::TransferReadOp which will later lower to
163 // vector.load or memref.load depending on further lowerings.
164 assert((type == distributedVectorType || type == sequentialVectorType) &&
165 "Must load either the preregistered distributed or the "
166 "preregistered sequential type.");
167 SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
168 if (type == distributedVectorType) {
169 for (auto dimExpr : distributionMap.getResults()) {
170 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171 indices[index] = buildDistributedOffset(b, loc, index);
172 }
173 }
174 SmallVector<bool> inBounds(indices.size(), true);
175 return vector::TransferReadOp::create(
176 b, loc, cast<VectorType>(type), buffer, indices,
177 /*padding=*/std::nullopt,
178 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
179 }
180
181 Value sequentialVal, distributedVal, laneId, zero;
182 VectorType sequentialVectorType, distributedVectorType;
183 AffineMap distributionMap;
184};
185
186} // namespace
187
188// Clones `op` into a new operation that takes `operands` and returns
189// `resultTypes`.
190static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
191 Location loc, Operation *op,
192 ArrayRef<Value> operands,
193 ArrayRef<Type> resultTypes) {
194 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
195 op->getAttrs());
196 return rewriter.create(res);
197}
198
199namespace {
200
201/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
202/// thread `laneId` executes the entirety of the computation.
203///
204/// After the transformation:
205/// - the IR within the scf.if op can be thought of as executing sequentially
206/// (from the point of view of threads along `laneId`).
207/// - the IR outside of the scf.if op can be thought of as executing in
208/// parallel (from the point of view of threads along `laneId`).
209///
210/// Values that need to transit through the parallel / sequential and the
211/// sequential / parallel boundaries do so via reads and writes to a temporary
212/// memory location.
213///
214/// The transformation proceeds in multiple steps:
215/// 1. Create the scf.if op.
216/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
217/// within the scf.if to transit the values captured from above.
218/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
219/// consistent within the scf.if.
220/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
221/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
222/// transit the values returned by the op.
223/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
224/// consistent after the scf.if.
225/// 7. Perform late cleanups.
226///
227/// All this assumes the vector distribution occurs along the most minor
228/// distributed vector dimension.
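///
/// As a rough sketch (operand names and types below are illustrative, not
/// taken from an actual test):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32]
///     args(%v : vector<4xf32>) -> (vector<1xf32>) {
/// ^bb0(%arg : vector<128xf32>):
///   ...
///   gpu.yield %val : vector<32xf32>
/// }
/// ```
/// becomes an alloc + store of the distributed %v before `scf.if %laneid == 0`,
/// a load of the sequential vector<128xf32> at the start of the scf.if, the
/// original body, a store of the yielded vector<32xf32> before the scf.if
/// terminator, and a load of the distributed vector<1xf32> after the scf.if,
/// with warp synchronization around both boundaries.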
229struct WarpOpToScfIfPattern : public WarpDistributionPattern {
230 WarpOpToScfIfPattern(MLIRContext *context,
231 const WarpExecuteOnLane0LoweringOptions &options,
232 PatternBenefit benefit = 1)
233 : WarpDistributionPattern(context, benefit), options(options) {}
234
235 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236 PatternRewriter &rewriter) const override {
237 assert(warpOp.getBodyRegion().hasOneBlock() &&
238 "expected WarpOp with single block");
239 Block *warpOpBody = &warpOp.getBodyRegion().front();
240 Location loc = warpOp.getLoc();
241
242 // Passed all checks. Start rewriting.
243 OpBuilder::InsertionGuard g(rewriter);
244 rewriter.setInsertionPoint(warpOp);
245
246 // Step 1: Create scf.if op.
247 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
248 Value isLane0 = arith::CmpIOp::create(
249 rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
250 auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
251 /*withElseRegion=*/false);
252 rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
253
254 // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
255 // reads within the scf.if to transit the values captured from above.
256 SmallVector<Value> bbArgReplacements;
257 for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
258 Value sequentialVal = warpOpBody->getArgument(it.index());
259 Value distributedVal = it.value();
260 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
261 warpOp.getLaneid(), c0);
262
263 // Create buffer before the ifOp.
264 rewriter.setInsertionPoint(ifOp);
265 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
266 sequentialVal.getType());
267 // Store distributed vector into buffer, before the ifOp.
268 helper.buildStore(rewriter, loc, distributedVal, buffer);
269 // Load sequential vector from buffer, inside the ifOp.
270 rewriter.setInsertionPointToStart(ifOp.thenBlock());
271 bbArgReplacements.push_back(
272 helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
273 }
274
275 // Step 3. Insert sync after all the stores and before all the loads.
276 if (!warpOp.getArgs().empty()) {
277 rewriter.setInsertionPoint(ifOp);
278 options.warpSyncronizationFn(loc, rewriter, warpOp);
279 }
280
281 // Step 4. Move body of warpOp to ifOp.
282 rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
283
284 // Step 5. Insert appropriate writes within scf.if and reads after the
285 // scf.if to transit the values returned by the op.
286 // TODO: at this point, we can reuse the shared memory from previous
287 // buffers.
288 SmallVector<Value> replacements;
289 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
290 Location yieldLoc = yieldOp.getLoc();
291 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
292 Value sequentialVal = it.value();
293 Value distributedVal = warpOp->getResult(it.index());
294 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
295 warpOp.getLaneid(), c0);
296
297 // Create buffer before the ifOp.
298 rewriter.setInsertionPoint(ifOp);
299 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
300 sequentialVal.getType());
301
302 // Store yielded value into buffer, inside the ifOp, before the
303 // terminator.
304 rewriter.setInsertionPoint(yieldOp);
305 helper.buildStore(rewriter, loc, sequentialVal, buffer);
306
307 // Load distributed value from buffer, after the warpOp.
308 rewriter.setInsertionPointAfter(ifOp);
309 // Result type and yielded value type are the same. This is a broadcast.
310 // E.g.:
311 // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
312 // gpu.yield %cst : f32
313 // }
314 // Both types are f32. The constant %cst is broadcasted to all lanes.
315 // This is described in more detail in the documentation of the op.
316 replacements.push_back(
317 helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
318 }
319
320 // Step 6. Insert sync after all the stores and before all the loads.
321 if (!yieldOp.getOperands().empty()) {
322 rewriter.setInsertionPointAfter(ifOp);
323 options.warpSyncronizationFn(loc, rewriter, warpOp);
324 }
325
326 // Step 7. Delete terminator and add empty scf.yield.
327 rewriter.eraseOp(yieldOp);
328 rewriter.setInsertionPointToEnd(ifOp.thenBlock());
329 scf::YieldOp::create(rewriter, yieldLoc);
330
331 // Compute replacements for WarpOp results.
332 rewriter.replaceOp(warpOp, replacements);
333
334 return success();
335 }
336
337private:
338 const WarpExecuteOnLane0LoweringOptions &options;
339};
340
341/// Return the distributed vector type based on the original type and the
342/// distribution map. The map is expected to have a number of dimensions equal
343/// to the original type rank and should be a projection whose results are the
344/// distributed dimensions. If the number of results is zero, there is no
345/// distribution (i.e. the original type is returned).
346/// Otherwise, the number of results should be equal to the number
347/// of warp sizes, which is currently limited to 1.
348/// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
349/// and a warp size of 16 distributes the second dimension (associated with
350/// d1) and returns vector<16x2x64>.
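/// As a further illustration of the non-dividing case: if the mapped dimension
/// is smaller than the warp size (say size 8 with a warp size of 16), that
/// dimension is shrunk to 1 and the leftover factor (2) must be consumed by a
/// later mapped dimension; if any of the warp size is left unconsumed, a null
/// VectorType is returned to signal failure.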
351static VectorType getDistributedType(VectorType originalType, AffineMap map,
352 int64_t warpSize) {
353 // If the map has zero results, return the original type.
354 if (map.getNumResults() == 0)
355 return originalType;
356 SmallVector<int64_t> targetShape(originalType.getShape());
357 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
358 unsigned position = map.getDimPosition(i);
359 if (targetShape[position] % warpSize != 0) {
360 if (warpSize % targetShape[position] != 0) {
361 return VectorType();
362 }
363 warpSize /= targetShape[position];
364 targetShape[position] = 1;
365 continue;
366 }
367 targetShape[position] = targetShape[position] / warpSize;
368 warpSize = 1;
369 break;
370 }
371 if (warpSize != 1) {
372 return VectorType();
373 }
374 VectorType targetType =
375 VectorType::get(targetShape, originalType.getElementType());
376 return targetType;
377}
378
379/// Given a warpOp that contains ops with regions, the corresponding op's
380/// "inner" region and the distributionMapFn, get all values used by the op's
381/// region that are defined within the warpOp, but outside the inner region.
382/// Return the set of values, their types and their distributed types.
383std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
384 SmallVector<Type>>
385getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
386 DistributionMapFn distributionMapFn) {
387 llvm::SmallSetVector<Value, 32> escapingValues;
388 SmallVector<Type> escapingValueTypes;
389 SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
390 if (innerRegion.empty())
391 return {std::move(escapingValues), std::move(escapingValueTypes),
392 std::move(escapingValueDistTypes)};
393 mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
394 Operation *parent = operand->get().getParentRegion()->getParentOp();
395 if (warpOp->isAncestor(parent)) {
396 if (!escapingValues.insert(operand->get()))
397 return;
398 Type distType = operand->get().getType();
399 if (auto vecType = dyn_cast<VectorType>(distType)) {
400 AffineMap map = distributionMapFn(operand->get());
401 distType = getDistributedType(vecType, map, warpOp.getWarpSize());
402 }
403 escapingValueTypes.push_back(operand->get().getType());
404 escapingValueDistTypes.push_back(distType);
405 }
406 });
407 return {std::move(escapingValues), std::move(escapingValueTypes),
408 std::move(escapingValueDistTypes)};
409}
410
411/// Distribute transfer_write ops based on the affine map returned by
412/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
413/// will not be distributed (it should be less than the warp size).
414///
415/// Example:
416/// ```
417/// %0 = gpu.warp_execute_on_lane_0(%id){
418/// ...
419/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
420/// gpu.yield
421/// }
422/// ```
423/// To
424/// ```
425/// %r:3 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
426/// ...
427/// gpu.yield %v : vector<32xf32>
428/// }
429/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
430struct WarpOpTransferWrite : public WarpDistributionPattern {
431 WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
432 unsigned maxNumElementsToExtract, PatternBenefit b = 1)
433 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
434 maxNumElementsToExtract(maxNumElementsToExtract) {}
435
436 /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
437 /// are multiples of the distribution ratio are supported at the moment.
438 LogicalResult tryDistributeOp(RewriterBase &rewriter,
439 vector::TransferWriteOp writeOp,
440 WarpExecuteOnLane0Op warpOp) const {
441 VectorType writtenVectorType = writeOp.getVectorType();
442
443 // 1. If the write is 0-D, bail out: the extract path (`tryExtractOp`) will
444 // clone it into a separate WarpExecuteOnLane0Op instead.
445 if (writtenVectorType.getRank() == 0)
446 return failure();
447
448 // 2. Compute the distributed type.
449 AffineMap map = distributionMapFn(writeOp.getVector());
450 VectorType targetType =
451 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
452 if (!targetType)
453 return failure();
454
455 // 2.5 Compute the distributed type for the new mask;
456 VectorType maskType;
457 if (writeOp.getMask()) {
458 // TODO: Distribution of masked writes with non-trivial permutation maps
459 // requires the distribution of the mask to elementwise match the
460 // distribution of the permuted written vector. Currently the details
461 // of which lane is responsible for which element is captured strictly
462 // by shape information on the warp op, and thus requires materializing
463 // the permutation in IR.
464 if (!writeOp.getPermutationMap().isMinorIdentity())
465 return failure();
466 maskType =
467 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
468 }
469
470 // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
471 // the rest.
472 vector::TransferWriteOp newWriteOp =
473 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
474
475 // 4. Reindex the write using the distribution map.
476 auto newWarpOp =
477 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
478
479 // Delinearize the lane id based on the way threads are divided across the
480 // vector. To get the number of threads per vector dimension, divide the
481 // sequential size by the distributed size along each dim.
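 // E.g. (illustrative): distributing a vector<4x32xf32> write to
 // vector<4x1xf32> over 32 lanes yields delinearized id sizes [1, 32].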
482 rewriter.setInsertionPoint(newWriteOp);
483 SmallVector<OpFoldResult> delinearizedIdSizes;
484 for (auto [seqSize, distSize] :
485 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
486 assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
487 delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
488 }
489 SmallVector<Value> delinearized;
490 if (map.getNumResults() > 1) {
491 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
492 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
493 delinearizedIdSizes)
494 .getResults();
495 } else {
496 // If there is only one map result, we can elide the delinearization
497 // op and use the lane id directly.
498 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
499 }
500
501 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
502 Location loc = newWriteOp.getLoc();
503 SmallVector<Value> indices(newWriteOp.getIndices().begin(),
504 newWriteOp.getIndices().end());
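 // Along each distributed dimension, shift the original write index by this
 // lane's offset: `index + laneId * distDimSize`. Illustratively, when a
 // vector<32xf32> value is distributed as vector<1xf32>, lane 3 writes at
 // `index + 3`.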
505 for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
506 AffineExpr d0, d1;
507 bindDims(newWarpOp.getContext(), d0, d1);
508 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
509 if (!indexExpr)
510 continue;
511 unsigned indexPos = indexExpr.getPosition();
512 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
513 Value laneId = delinearized[vectorPos];
514 auto scale =
515 rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
516 indices[indexPos] = affine::makeComposedAffineApply(
517 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
518 }
519 newWriteOp.getIndicesMutable().assign(indices);
520
521 return success();
522 }
523
524 /// Extract TransferWriteOps of vector<1x> into a separate warp op.
525 LogicalResult tryExtractOp(RewriterBase &rewriter,
526 vector::TransferWriteOp writeOp,
527 WarpExecuteOnLane0Op warpOp) const {
528 Location loc = writeOp.getLoc();
529 VectorType vecType = writeOp.getVectorType();
530
531 if (vecType.getNumElements() > maxNumElementsToExtract) {
532 return rewriter.notifyMatchFailure(
533 warpOp,
534 llvm::formatv(
535 "writes more elements ({0}) than allowed to extract ({1})",
536 vecType.getNumElements(), maxNumElementsToExtract));
537 }
538
539 // Do not process warp ops that contain only TransferWriteOps.
540 if (llvm::all_of(warpOp.getOps(),
541 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
542 return failure();
543
544 SmallVector<Value> yieldValues = {writeOp.getVector()};
545 SmallVector<Type> retTypes = {vecType};
546 SmallVector<size_t> newRetIndices;
547 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
548 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
549 rewriter.setInsertionPointAfter(newWarpOp);
550
551 // Create a second warp op that contains only writeOp.
552 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
553 newWarpOp.getLaneid(),
554 newWarpOp.getWarpSize());
555 Block &body = secondWarpOp.getBodyRegion().front();
556 rewriter.setInsertionPointToStart(&body);
557 auto newWriteOp =
558 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
559 newWriteOp.getValueToStoreMutable().assign(
560 newWarpOp.getResult(newRetIndices[0]));
561 rewriter.eraseOp(writeOp);
562 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
563 return success();
564 }
565
566 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
567 PatternRewriter &rewriter) const override {
568 gpu::YieldOp yield = warpOp.getTerminator();
569 Operation *lastNode = yield->getPrevNode();
570 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
571 if (!writeOp)
572 return failure();
573
574 Value maybeMask = writeOp.getMask();
575 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
576 return writeOp.getVector() == value ||
577 (maybeMask && maybeMask == value) ||
578 warpOp.isDefinedOutsideOfRegion(value);
579 }))
580 return failure();
581
582 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
583 return success();
584
585 // Masked writes not supported for extraction.
586 if (writeOp.getMask())
587 return failure();
588
589 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
590 return success();
591
592 return failure();
593 }
594
595private:
596 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
597 /// execute op with the proper return type. The new write op is updated to
598 /// write the result of the new warp execute op. The old `writeOp` is deleted.
599 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
600 WarpExecuteOnLane0Op warpOp,
601 vector::TransferWriteOp writeOp,
602 VectorType targetType,
603 VectorType maybeMaskType) const {
604 assert(writeOp->getParentOp() == warpOp &&
605 "write must be nested immediately under warp");
606 OpBuilder::InsertionGuard g(rewriter);
607 SmallVector<size_t> newRetIndices;
608 WarpExecuteOnLane0Op newWarpOp;
609 if (maybeMaskType) {
610 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
611 rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
612 TypeRange{targetType, maybeMaskType}, newRetIndices);
613 } else {
614 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
615 rewriter, warpOp, ValueRange{{writeOp.getVector()}},
616 TypeRange{targetType}, newRetIndices);
617 }
618 rewriter.setInsertionPointAfter(newWarpOp);
619 auto newWriteOp =
620 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
621 rewriter.eraseOp(writeOp);
622 newWriteOp.getValueToStoreMutable().assign(
623 newWarpOp.getResult(newRetIndices[0]));
624 if (maybeMaskType)
625 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
626 return newWriteOp;
627 }
628
629 DistributionMapFn distributionMapFn;
630 unsigned maxNumElementsToExtract = 1;
631};
632
633/// Sink out elementwise op feeding into a warp op yield.
634/// ```
635/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
636/// ...
637/// %3 = arith.addf %1, %2 : vector<32xf32>
638/// gpu.yield %3 : vector<32xf32>
639/// }
640/// ```
641/// To
642/// ```
643/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
644/// vector<1xf32>, vector<1xf32>) {
645/// ...
646/// %4 = arith.addf %2, %3 : vector<32xf32>
647/// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
648/// vector<32xf32>
649/// }
650/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
651struct WarpOpElementwise : public WarpDistributionPattern {
652 using Base::Base;
653 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
654 PatternRewriter &rewriter) const override {
655 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
656 return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
657 });
658 if (!yieldOperand)
659 return failure();
660
661 Operation *elementWise = yieldOperand->get().getDefiningOp();
662 unsigned operandIndex = yieldOperand->getOperandNumber();
663 Value distributedVal = warpOp.getResult(operandIndex);
664 SmallVector<Value> yieldValues;
665 SmallVector<Type> retTypes;
666 Location loc = warpOp.getLoc();
667 for (OpOperand &operand : elementWise->getOpOperands()) {
668 Type targetType;
669 if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
670 // If the result type is a vector, the operands must also be vectors.
671 auto operandType = cast<VectorType>(operand.get().getType());
672 targetType =
673 VectorType::get(vecType.getShape(), operandType.getElementType());
674 } else {
675 auto operandType = operand.get().getType();
676 assert(!isa<VectorType>(operandType) &&
677 "unexpected yield of vector from op with scalar result type");
678 targetType = operandType;
679 }
680 retTypes.push_back(targetType);
681 yieldValues.push_back(operand.get());
682 }
683 SmallVector<size_t> newRetIndices;
684 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
685 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
686 rewriter.setInsertionPointAfter(newWarpOp);
687 SmallVector<Value> newOperands(elementWise->getOperands().begin(),
688 elementWise->getOperands().end());
689 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
690 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
691 }
692 OpBuilder::InsertionGuard g(rewriter);
693 rewriter.setInsertionPointAfter(newWarpOp);
694 Operation *newOp = cloneOpWithOperandsAndTypes(
695 rewriter, loc, elementWise, newOperands,
696 {newWarpOp.getResult(operandIndex).getType()});
697 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
698 newOp->getResult(0));
699 return success();
700 }
701};
702
703/// Sink out splat constant op feeding into a warp op yield.
704/// ```
705/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
706/// ...
707/// %cst = arith.constant dense<2.0> : vector<32xf32>
708/// gpu.yield %cst : vector<32xf32>
709/// }
710/// ```
711/// To
712/// ```
713/// gpu.warp_execute_on_lane_0(%arg0) {
714/// ...
715/// }
716/// %0 = arith.constant dense<2.0> : vector<1xf32>
717struct WarpOpConstant : public WarpDistributionPattern {
718 using Base::Base;
719 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
720 PatternRewriter &rewriter) const override {
721 OpOperand *yieldOperand =
722 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
723 if (!yieldOperand)
724 return failure();
725 auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
726 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
727 if (!dense)
728 return failure();
729 // Notify the rewriter that the warp op is changing (see the comment on
730 // the WarpOpTransferRead pattern).
731 rewriter.startOpModification(warpOp);
732 unsigned operandIndex = yieldOperand->getOperandNumber();
733 Attribute scalarAttr = dense.getSplatValue<Attribute>();
734 auto newAttr = DenseElementsAttr::get(
735 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
736 Location loc = warpOp.getLoc();
737 rewriter.setInsertionPointAfter(warpOp);
738 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
739 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
740 rewriter.finalizeOpModification(warpOp);
741 return success();
742 }
743};
744
745/// Sink out step op feeding into a warp op yield.
746/// Vector step op is treated similarly to arith.constant, apart from
747/// the result that represents a sequence [0, vec_size).
748/// Due to the vec_size == warp_size limitation,
749/// we can simply wrap the lane id into a vector (i.e., broadcast).
750/// Supporting vec_size != warp_size may involve preserving the step
751/// result and using additional arith ops (the exact details are TBD).
752/// ```
753/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
754/// ...
755/// %cst = vector.step : vector<32xindex>
756/// gpu.yield %cst : vector<32xindex>
757/// }
758/// ```
759/// To
760/// ```
761/// gpu.warp_execute_on_lane_0(%arg0) {
762/// ...
763/// }
764/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
765struct WarpOpStep final : public WarpDistributionPattern {
766 using Base::Base;
767 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
768 PatternRewriter &rewriter) const override {
769 OpOperand *yieldOperand =
770 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
771 if (!yieldOperand)
772 return failure();
773 const unsigned operandIdx = yieldOperand->getOperandNumber();
774 auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
775 VectorType resTy = stepOp.getResult().getType();
776 if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
777 return rewriter.notifyMatchFailure(
778 warpOp,
779 llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
780 resTy.getNumElements(), warpOp.getWarpSize()));
781 VectorType newVecTy =
782 cast<VectorType>(warpOp.getResult(operandIdx).getType());
783 rewriter.setInsertionPointAfter(warpOp);
784 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
785 newVecTy, warpOp.getLaneid());
786 rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
787 return success();
788 }
789};
790
791/// Sink out transfer_read op feeding into a warp op yield.
792/// ```
793/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
794/// ...
795/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
796/// vector<32xf32>
797/// gpu.yield %2 : vector<32xf32>
798/// }
799/// ```
800/// To
801/// ```
802/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
803/// vector<1xf32>, vector<1xf32>) {
804/// ...
805/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
806/// vector<32xf32>
/// gpu.yield %2 : vector<32xf32>
807/// }
808/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
809struct WarpOpTransferRead : public WarpDistributionPattern {
810 using Base::Base;
811 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
812 PatternRewriter &rewriter) const override {
813 // Try to find a distributable yielded read. Note that this pattern can
814 // still fail at the end after distribution, in which case this might have
815 // missed another distributable read.
816 OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
817 // Don't duplicate transfer_read ops when distributing.
818 return isa<vector::TransferReadOp>(op) && op->hasOneUse();
819 });
820 if (!operand)
821 return rewriter.notifyMatchFailure(
822 warpOp, "warp result is not a vector.transfer_read op");
823 auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
824
825 // Source must be defined outside of the region.
826 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
827 return rewriter.notifyMatchFailure(
828 read, "source must be defined outside of the region");
829
830 unsigned operandIndex = operand->getOperandNumber();
831 Value distributedVal = warpOp.getResult(operandIndex);
832
833 SmallVector<Value, 4> indices(read.getIndices().begin(),
834 read.getIndices().end());
835 auto sequentialType = cast<VectorType>(read.getResult().getType());
836 auto distributedType = cast<VectorType>(distributedVal.getType());
837 AffineMap map = calculateImplicitMap(sequentialType, distributedType);
838 AffineMap indexMap = map.compose(read.getPermutationMap());
839
840 // Try to delinearize the lane ID to match the rank expected for
841 // distribution.
842 SmallVector<Value> delinearizedIds;
843 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
844 distributedType.getShape(), warpOp.getWarpSize(),
845 warpOp.getLaneid(), delinearizedIds)) {
846 return rewriter.notifyMatchFailure(
847 read, "cannot delinearize lane ID for distribution");
848 }
849 assert(!delinearizedIds.empty() || map.getNumResults() == 0);
850
851 // Distribute indices and the mask (if present).
852 OpBuilder::InsertionGuard g(rewriter);
853 SmallVector<Value> additionalResults(indices.begin(), indices.end());
854 SmallVector<Type> additionalResultTypes(indices.size(),
855 rewriter.getIndexType());
856 additionalResults.push_back(read.getPadding());
857 additionalResultTypes.push_back(read.getPadding().getType());
858
859 bool hasMask = false;
860 if (read.getMask()) {
861 hasMask = true;
862 // TODO: Distribution of masked reads with non-trivial permutation maps
863 // requires the distribution of the mask to elementwise match the
864 // distribution of the permuted read vector. Currently, the details
865 // of which lane is responsible for which element are captured strictly
866 // by shape information on the warp op; supporting this thus requires
867 // materializing the permutation in IR.
868 if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
869 return rewriter.notifyMatchFailure(
870 read, "non-trivial permutation maps not supported");
871 VectorType maskType =
872 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
873 additionalResults.push_back(read.getMask());
874 additionalResultTypes.push_back(maskType);
875 }
876
877 SmallVector<size_t> newRetIndices;
878 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
879 rewriter, warpOp, additionalResults, additionalResultTypes,
880 newRetIndices);
881 distributedVal = newWarpOp.getResult(operandIndex);
882
883 // Distributed indices were appended first.
884 SmallVector<Value> newIndices;
885 for (int64_t i = 0, e = indices.size(); i < e; ++i)
886 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
887
888 rewriter.setInsertionPointAfter(newWarpOp);
889 for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
890 AffineExpr d0, d1;
891 bindDims(read.getContext(), d0, d1);
892 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
893 if (!indexExpr)
894 continue;
895 unsigned indexPos = indexExpr.getPosition();
896 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
897 int64_t scale = distributedType.getDimSize(vectorPos);
898 newIndices[indexPos] = affine::makeComposedAffineApply(
899 rewriter, read.getLoc(), d0 + scale * d1,
900 {newIndices[indexPos], delinearizedIds[vectorPos]});
901 }
902
903 // Distributed padding value was appended right after the indices.
904 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
905 // Distributed mask value was added at the end (if the op has a mask).
906 Value newMask =
907 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
908 : Value();
909 auto newRead = vector::TransferReadOp::create(
910 rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
911 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
912 read.getInBoundsAttr());
913
914 rewriter.replaceAllUsesWith(distributedVal, newRead);
915 return success();
916 }
917};
918
919/// Remove any result that has no use along with the matching yieldOp operand.
920// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
921struct WarpOpDeadResult : public WarpDistributionPattern {
922 using Base::Base;
923 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
924 PatternRewriter &rewriter) const override {
925 SmallVector<Type> newResultTypes;
926 newResultTypes.reserve(warpOp->getNumResults());
927 SmallVector<Value> newYieldValues;
928 newYieldValues.reserve(warpOp->getNumResults());
929 DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
930 DenseMap<OpResult, int64_t> dedupResultPositionMap;
931 gpu::YieldOp yield = warpOp.getTerminator();
932
933 // Some values may be yielded multiple times and correspond to multiple
934 // results. Deduplicating occurs by taking each result with its matching
935 // yielded value, and:
936 // 1. recording the unique first position at which the value with uses is
937 // yielded.
938 // 2. recording for the result, the first position at which the dedup'ed
939 // value is yielded.
940 // 3. skipping from the new result types / new yielded values any result
941 // that has no use or whose yielded value has already been seen.
942 for (OpResult result : warpOp.getResults()) {
943 if (result.use_empty())
944 continue;
945 Value yieldOperand = yield.getOperand(result.getResultNumber());
946 auto it = dedupYieldOperandPositionMap.insert(
947 std::make_pair(yieldOperand, newResultTypes.size()));
948 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
949 if (!it.second)
950 continue;
951 newResultTypes.push_back(result.getType());
952 newYieldValues.push_back(yieldOperand);
953 }
954 // No modification, exit early.
955 if (yield.getNumOperands() == newYieldValues.size())
956 return failure();
957 // Move the body of the old warpOp to a new warpOp.
958 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
959 rewriter, warpOp, newYieldValues, newResultTypes);
960
961 // Simplify the new warp op after dropping dead results.
962 newWarpOp.getBody()->walk([&](Operation *op) {
963 if (isOpTriviallyDead(op))
964 rewriter.eraseOp(op);
965 });
966
967 // Replace results of the old warpOp by the new, deduplicated results.
968 SmallVector<Value> newValues;
969 newValues.reserve(warpOp->getNumResults());
970 for (OpResult result : warpOp.getResults()) {
971 if (result.use_empty())
972 newValues.push_back(Value());
973 else
974 newValues.push_back(
975 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
976 }
977 rewriter.replaceOp(warpOp, newValues);
978 return success();
979 }
980};
981
982// If an operand is directly yielded out of the region, we can forward it;
983// it doesn't need to go through the region.
984struct WarpOpForwardOperand : public WarpDistributionPattern {
985 using Base::Base;
986 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
987 PatternRewriter &rewriter) const override {
988 gpu::YieldOp yield = warpOp.getTerminator();
989 Value valForwarded;
990 unsigned resultIndex;
991 for (OpOperand &operand : yield->getOpOperands()) {
992 Value result = warpOp.getResult(operand.getOperandNumber());
993 if (result.use_empty())
994 continue;
995
996 // Assume all the values coming from above are uniform.
997 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
998 if (result.getType() != operand.get().getType())
999 continue;
1000 valForwarded = operand.get();
1001 resultIndex = operand.getOperandNumber();
1002 break;
1003 }
1004 auto arg = dyn_cast<BlockArgument>(operand.get());
1005 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1006 continue;
1007 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1008 if (result.getType() != warpOperand.getType())
1009 continue;
1010 valForwarded = warpOperand;
1011 resultIndex = operand.getOperandNumber();
1012 break;
1013 }
1014 if (!valForwarded)
1015 return failure();
1016 // Notify the rewriter that the warp op is changing (see the comment on
1017 // the WarpOpTransferRead pattern).
1018 rewriter.startOpModification(warpOp);
1019 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1020 rewriter.finalizeOpModification(warpOp);
1021 return success();
1022 }
1023};
1024
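/// Sink out a broadcast of a warp-uniform value feeding into a warp op yield:
/// the broadcast source is yielded as-is and the broadcast is rebuilt on the
/// distributed result type after the warp op. Illustrative sketch (names and
/// types made up):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %1 = vector.broadcast %s : f32 to vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (f32) {
///   ...
///   gpu.yield %s : f32
/// }
/// %0 = vector.broadcast %r : f32 to vector<1xf32>
/// ```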
1025struct WarpOpBroadcast : public WarpDistributionPattern {
1026 using Base::Base;
1027 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1028 PatternRewriter &rewriter) const override {
1029 OpOperand *operand =
1030 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1031 if (!operand)
1032 return failure();
1033 unsigned int operandNumber = operand->getOperandNumber();
1034 auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1035 Location loc = broadcastOp.getLoc();
1036 auto destVecType =
1037 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1038 Value broadcastSrc = broadcastOp.getSource();
1039 Type broadcastSrcType = broadcastSrc.getType();
1040
1041 // Check that the broadcast actually spans a set of values uniformly across
1042 // all threads. In other words, check that each thread can reconstruct
1043 // its own broadcast.
1044 // For that we simply check that the broadcast we want to build makes sense.
1045 if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1046 vector::BroadcastableToResult::Success)
1047 return failure();
1048 SmallVector<size_t> newRetIndices;
1049 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1050 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1051 rewriter.setInsertionPointAfter(newWarpOp);
1052 Value broadcasted = vector::BroadcastOp::create(
1053 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1054 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1055 broadcasted);
1056 return success();
1057 }
1058};
1059
1060/// Pattern to move shape cast out of the warp op. shape cast is basically a
1061/// no-op for warp distribution; we need to handle the shape though.
1062struct WarpOpShapeCast : public WarpDistributionPattern {
1063 using Base::Base;
1064 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1065 PatternRewriter &rewriter) const override {
1066 OpOperand *operand =
1067 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1068 if (!operand)
1069 return failure();
1070
1071 auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1072
1073 unsigned int operandNumber = operand->getOperandNumber();
1074 auto castDistributedType =
1075 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1076 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1077 VectorType castResultType = castDistributedType;
1078
1079 // We expect the distributed type to have a smaller rank than the original
1080 // type. Prepend with size-one dimensions to make them the same.
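 // E.g. (illustrative): for a cast from vector<1x32xf32> whose distributed
 // result type is vector<1xf32>, the source's requested distributed type
 // becomes vector<1x1xf32>.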
1081 unsigned castDistributedRank = castDistributedType.getRank();
1082 unsigned castOriginalRank = castOriginalType.getRank();
1083 if (castDistributedRank < castOriginalRank) {
1084 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1085 llvm::append_range(shape, castDistributedType.getShape());
1086 castDistributedType =
1087 VectorType::get(shape, castDistributedType.getElementType());
1088 }
1089
1090 SmallVector<size_t> newRetIndices;
1091 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1092 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1093 newRetIndices);
1094 rewriter.setInsertionPointAfter(newWarpOp);
1095 Value newCast = vector::ShapeCastOp::create(
1096 rewriter, oldCastOp.getLoc(), castResultType,
1097 newWarpOp->getResult(newRetIndices[0]));
1098 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1099 return success();
1100 }
1101};
1102
1103/// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op
1104/// yield.
1105/// ```
1106/// %0 = ...
1107/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1108/// ...
1109/// %mask = vector.create_mask %0 : vector<32xi1>
1110/// // or %mask = vector.constant_mask[2] : vector<32xi1>
1111/// gpu.yield %mask : vector<32xi1>
1112/// }
1113/// ```
1114/// To
1115/// ```
1116/// %0 = ...
1117/// gpu.warp_execute_on_lane_0(%arg0) {
1118/// ...
1119/// }
1120/// %cmp = arith.cmpi ult, %laneid, %0
1121/// %ub = arith.select %cmp, %c0, %c1
1122/// %1 = vector.create_mask %ub : vector<1xi1>
1123template <typename OpType,
1124 typename = std::enable_if_t<llvm::is_one_of<
1125 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
1126struct WarpOpCreateMask : public WarpDistributionPattern {
1127 using Base::Base;
1128 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1129 PatternRewriter &rewriter) const override {
1130 OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
1131 if (!yieldOperand)
1132 return failure();
1133
1134 Operation *mask = yieldOperand->get().getDefiningOp<OpType>();
1135
1136 // Early exit if any values needed for calculating the new mask indices
1137 // are defined inside the warp op.
1138 if (mask->getOperands().size() &&
1139 !llvm::all_of(mask->getOperands(), [&](Value value) {
1140 return warpOp.isDefinedOutsideOfRegion(value);
1141 }))
1142 return failure();
1143
1144 Location loc = mask->getLoc();
1145 unsigned operandIndex = yieldOperand->getOperandNumber();
1146
1147 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1148 VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
1149 ArrayRef<int64_t> seqShape = seqType.getShape();
1150 ArrayRef<int64_t> distShape = distType.getShape();
1151 SmallVector<Value> materializedOperands;
1152 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
1153 materializedOperands.append(mask->getOperands().begin(),
1154 mask->getOperands().end());
1155 } else {
1156 auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
1157 auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
1158 for (auto dimSize : dimSizes)
1159 materializedOperands.push_back(
1160 arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
1161 }
1162
1163 rewriter.setInsertionPointAfter(warpOp);
1164
1165 // Delinearize the lane ID for constructing the distributed mask sizes.
1166 SmallVector<Value> delinearizedIds;
1167 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1168 warpOp.getWarpSize(), warpOp.getLaneid(),
1169 delinearizedIds))
1170 return rewriter.notifyMatchFailure(
1171 mask, "cannot delinearize lane ID for distribution");
1172 assert(!delinearizedIds.empty());
1173
1174 // Notify the rewriter that the warp op is changing (see the comment on
1175 // the WarpOpTransferRead pattern).
1176 rewriter.startOpModification(warpOp);
1177
1178 AffineExpr s0, s1;
1179 bindSymbols(rewriter.getContext(), s0, s1);
1180 SmallVector<Value> newOperands;
1181 for (int i = 0, e = distShape.size(); i < e; ++i) {
1182 // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1183 // find the distance from the largest mask index owned by this lane to the
1184 // original mask size. `vector.create_mask` implicitly clamps mask
1185 // operands to the range [0, mask_vector_size[i]], or in other words, the
1186 // mask sizes are always in the range [0, mask_vector_size[i]).
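 // Illustrative numbers: for an original mask size of 11 along dim i with
 // distShape[i] == 1 and 32 lanes, lane 10 computes 11 - 10 * 1 == 1 (one
 // active element) while lane 12 computes a negative value, which
 // vector.create_mask clamps to 0 (no active elements).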
1187 Value maskDimIdx = affine::makeComposedAffineApply(
1188 rewriter, loc, s1 - s0 * distShape[i],
1189 {delinearizedIds[i], materializedOperands[i]});
1190 newOperands.push_back(maskDimIdx);
1191 }
1192
1193 auto newMask =
1194 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
1195 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1196 rewriter.finalizeOpModification(warpOp);
1197 return success();
1198 }
1199};
1200
1201/// Sink out insert_strided_slice op feeding into a warp op yield.
1202/// ```
1203/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1204/// ...
1205/// %src = ... : vector<4x32xf32>
1206/// %dest = ... : vector<8x32xf32>
1207/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1208/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1209/// gpu.yield %insert : vector<8x32xf32>
1210/// }
1211/// ```
1212/// To
1213/// ```
1214/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1215/// vector<8x1xf32>) {
1216/// ...
1217/// %src = ... : vector<4x32xf32>
1218/// %dest = ... : vector<8x32xf32>
1219/// gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
1220/// }
1221/// %insert = vector.insert_strided_slice %0#0, %0#1,
1222/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1223/// ```
1224/// NOTE: Current support assumes that both src and dest vectors are distributed
1225/// to lanes and sinking the insert op does not require any cross lane
1226/// communication.
1227struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
1228 using Base::Base;
1229 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1230 PatternRewriter &rewriter) const override {
1231 OpOperand *operand =
1232 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1233 if (!operand)
1234 return failure();
1235 unsigned int operandNumber = operand->getOperandNumber();
1236 auto insertOp =
1237 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1238 auto distributedType =
1239 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1240 // Distributed type must be 2D or higher.
1241 // TODO: Support 1D distributed types.
1242 if (distributedType.getRank() < 2)
1243 return rewriter.notifyMatchFailure(
1244 insertOp, "result vector type must be 2D or higher");
1245 // Find the distributed dimension of the dest vector. There should be
1246 // exactly one.
1247 auto yieldedType = cast<VectorType>(operand->get().getType());
1248 int64_t destDistributedDim =
1249 getDistributedDim(yieldedType, distributedType);
1250 assert(destDistributedDim != -1 && "could not find distributed dimension");
1251
1252 VectorType srcType = insertOp.getSourceVectorType();
1253 VectorType destType = insertOp.getDestVectorType();
1254 // Currently we require that both source (kD) and dest (nD) vectors are
1255 // distributed. This requires that distributedDim (d) is contained in the
1256 // last k dims of the dest vector (d >= n - k).
1257 // TODO: Add support for case where source vector is not distributed.
1258 int64_t sourceDistributedDim =
1259 destDistributedDim - (destType.getRank() - srcType.getRank());
1260 if (sourceDistributedDim < 0)
1261 return rewriter.notifyMatchFailure(
1262 insertOp,
1263 "distributed dimension must be in the last k dims of dest vector");
1264 // Distributed dimension must be fully inserted.
1265 if (srcType.getDimSize(sourceDistributedDim) !=
1266 destType.getDimSize(destDistributedDim))
1267 return rewriter.notifyMatchFailure(
1268 insertOp, "distributed dimension must be fully inserted");
1269 SmallVector<int64_t> newSourceDistShape(
1270 insertOp.getSourceVectorType().getShape());
1271 newSourceDistShape[sourceDistributedDim] =
1272 distributedType.getDimSize(destDistributedDim);
1273 auto newSourceTy =
1274 VectorType::get(newSourceDistShape, distributedType.getElementType());
1275 VectorType newDestTy = distributedType;
1276 SmallVector<size_t> newRetIndices;
1277 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1278 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1279 {newSourceTy, newDestTy}, newRetIndices);
1280 rewriter.setInsertionPointAfter(newWarpOp);
1281 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1282 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1283 // Create a new insert strided slice op that inserts distributed source into
1284 // distributed dest.
1285 Value newInsert = vector::InsertStridedSliceOp::create(
1286 rewriter, insertOp.getLoc(), distributedDest.getType(),
1287 distributedSource, distributedDest, insertOp.getOffsets(),
1288 insertOp.getStrides());
1289 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
1290 return success();
1291 }
1292};
1293
1294/// Sink out extract_strided_slice op feeding into a warp op yield.
1295/// ```
1296/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1297/// ...
1298/// %src = ... : vector<64x32xf32>
1299/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1300/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1301/// gpu.yield %extract : vector<16x32xf32>
1302/// }
1303/// ```
1304/// To
1305/// ```
1306/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1307/// ...
1308/// %src = ... : vector<64x32xf32>
1309/// gpu.yield %src : vector<64x32xf32>
1310/// }
1311/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1312/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1313/// ```
1314/// NOTE: Current support assumes that the extraction happens only on
1315/// non-distributed dimensions (i.e. it does not require cross-lane
1315/// communication).
1316struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
1317 using Base::Base;
1318 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1319 PatternRewriter &rewriter) const override {
1320 OpOperand *operand =
1321 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1322 if (!operand)
1323 return failure();
1324 unsigned int operandNumber = operand->getOperandNumber();
1325 auto extractOp =
1326 operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
1327 auto distributedType =
1328 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1329 // Distributed type must be 2D or higher.
1330 // TODO: Support 1D distributed types.
1331 if (distributedType.getRank() < 2)
1332 return rewriter.notifyMatchFailure(
1333 extractOp, "result vector type must be 2D or higher");
1334
1335 // Find the distributed dimension. There should be exactly one.
1336 auto yieldedType = cast<VectorType>(operand->get().getType());
1337 int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1338 assert(distributedDim != -1 && "could not find distributed dimension");
1339
1340 int64_t numOfExtractedDims =
1341 static_cast<int64_t>(extractOp.getSizes().size());
1342 // If the distributed dim is included in the extracted dims, then we make
1343 // sure distributed dim is fully extracted. If distributed dim is not
1344 // included in extracted dims, it is guaranteed to be fully extracted (i.e.
1345 // distributed dim comes after all the extracted dims)
1346 // TODO: Partial extraction from the distributed dimension requires
1347 // cross-lane communication.
1348 if (distributedDim < numOfExtractedDims) {
1349 int64_t distributedDimOffset =
1350 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1351 .getInt();
1352 int64_t distributedDimSize =
1353 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1354 .getInt();
1355 if (distributedDimOffset != 0 ||
1356 distributedDimSize != yieldedType.getDimSize(distributedDim))
1357 return rewriter.notifyMatchFailure(
1358 extractOp, "distributed dimension must be fully extracted");
1359 }
1360 SmallVector<int64_t> newDistributedShape(
1361 extractOp.getSourceVectorType().getShape());
1362 newDistributedShape[distributedDim] =
1363 distributedType.getDimSize(distributedDim);
1364 auto newDistributedType =
1365 VectorType::get(newDistributedShape, distributedType.getElementType());
1366 SmallVector<size_t> newRetIndices;
1367 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1368 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1369 newRetIndices);
1370 rewriter.setInsertionPointAfter(newWarpOp);
1371 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1372 extractOp.getSizes(), [](Attribute attr) { return attr; });
1373 // Update the distributed sizes to match the distributed type.
1374 if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
1375 distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1376 distributedType.getDimSize(distributedDim));
1377
1378 // Create a new extract strided slice op that extracts from the
1379 // distributed vector.
1380 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1381 Value newExtract = vector::ExtractStridedSliceOp::create(
1382 rewriter, extractOp.getLoc(), distributedType, distributedVec,
1383 extractOp.getOffsets(),
1384 ArrayAttr::get(rewriter.getContext(), distributedSizes),
1385 extractOp.getStrides());
1386 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1387 newExtract);
1388 return success();
1389 }
1390};
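When the distributed dimension is itself covered by the extraction, the pattern additionally rewrites the corresponding `sizes` entry to the per-lane size. A hand-worked illustration (not taken from this file's tests; SSA names and the warp size of 32 are invented for the example):

```
// Before: the distributed dim (32 -> 1) is covered by sizes = [16, 32], so it
// must be fully extracted (offset 0, size 32).
%0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  %e = vector.extract_strided_slice %src
    {offsets = [0, 0], sizes = [16, 32], strides = [1, 1]}
    : vector<64x32xf32> to vector<16x32xf32>
  gpu.yield %e : vector<16x32xf32>
}

// After: the warp op yields the distributed source, and the sizes entry for
// the distributed dim is rewritten from 32 to 1.
%0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<64x1xf32>) {
  %src = "some_def"() : () -> (vector<64x32xf32>)
  gpu.yield %src : vector<64x32xf32>
}
%1 = vector.extract_strided_slice %0
  {offsets = [0, 0], sizes = [16, 1], strides = [1, 1]}
  : vector<64x1xf32> to vector<16x1xf32>
```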
1391
1392/// Pattern to move out vector.extract of single element vector. Those don't
1393/// need to be distributed and can just be propagated outside of the region.
1394struct WarpOpExtract : public WarpDistributionPattern {
1395 using Base::Base;
1396 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1397 PatternRewriter &rewriter) const override {
1398 OpOperand *operand =
1399 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1400 if (!operand)
1401 return failure();
1402 unsigned int operandNumber = operand->getOperandNumber();
1403 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1404 VectorType extractSrcType = extractOp.getSourceVectorType();
1405 Location loc = extractOp.getLoc();
1406
1407 // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1408 if (extractSrcType.getRank() <= 1) {
1409 return failure();
1410 }
1411
1412 // All following cases are 2d or higher dimensional source vectors.
1413
1414 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1415 // There is no distribution, this is a broadcast. Simply move the extract
1416 // out of the warp op.
1417 // TODO: This could be optimized. E.g., in case of a scalar result, let
1418 // one lane extract and shuffle the result to all other lanes (same as
1419 // the 1d case).
1420 SmallVector<size_t> newRetIndices;
1421 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1422 rewriter, warpOp, {extractOp.getSource()},
1423 {extractOp.getSourceVectorType()}, newRetIndices);
1424 rewriter.setInsertionPointAfter(newWarpOp);
1425 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1426 // Extract from distributed vector.
1427 Value newExtract = vector::ExtractOp::create(
1428 rewriter, loc, distributedVec, extractOp.getMixedPosition());
1429 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1430 newExtract);
1431 return success();
1432 }
1433
1434 // Find the distributed dimension. There should be exactly one.
1435 auto distributedType =
1436 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1437 auto yieldedType = cast<VectorType>(operand->get().getType());
1438 int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1439 assert(distributedDim != -1 && "could not find distributed dimension");
1440 (void)distributedDim;
1441
1442 // Yield source vector from warp op.
1443 SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1444 for (int i = 0; i < distributedType.getRank(); ++i)
1445 newDistributedShape[i + extractOp.getNumIndices()] =
1446 distributedType.getDimSize(i);
1447 auto newDistributedType =
1448 VectorType::get(newDistributedShape, distributedType.getElementType());
1449 SmallVector<size_t> newRetIndices;
1450 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1451 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1452 newRetIndices);
1453 rewriter.setInsertionPointAfter(newWarpOp);
1454 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1455 // Extract from distributed vector.
1456 Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1457 extractOp.getMixedPosition());
1458 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1459 newExtract);
1460 return success();
1461 }
1462};
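The effect of the pattern above is easiest to see on a small example. A sketch (warp size 32, SSA names invented for illustration): the extract's source is yielded instead, distributed on its trailing dimension, and the extract is replayed on the distributed vector outside the warp op.

```
// Before: the extract feeds the warp op yield.
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
  %src = "some_def"() : () -> (vector<4x96xf32>)
  %e = vector.extract %src[2] : vector<96xf32> from vector<4x96xf32>
  gpu.yield %e : vector<96xf32>
}

// After: the warp op yields the source distributed on the trailing dim
// (96 -> 3) and the extract operates on the distributed vector.
%s = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x3xf32>) {
  %src = "some_def"() : () -> (vector<4x96xf32>)
  gpu.yield %src : vector<4x96xf32>
}
%r = vector.extract %s[2] : vector<3xf32> from vector<4x3xf32>
```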
1463
1464/// Pattern to move out vector.extract with a scalar result.
1465/// Only supports 1-D and 0-D sources for now.
1466struct WarpOpExtractScalar : public WarpDistributionPattern {
1467 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1468 PatternBenefit b = 1)
1469 : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1470 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1471 PatternRewriter &rewriter) const override {
1472 OpOperand *operand =
1473 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1474 if (!operand)
1475 return failure();
1476 unsigned int operandNumber = operand->getOperandNumber();
1477 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1478 VectorType extractSrcType = extractOp.getSourceVectorType();
1479 // Only supports 1-D or 0-D sources for now.
1480 if (extractSrcType.getRank() > 1) {
1481 return rewriter.notifyMatchFailure(
1482 extractOp, "only 0-D or 1-D source supported for now");
1483 }
1484 // TODO: Supported shuffle types should be parameterizable, similar to
1485 // `WarpShuffleFromIdxFn`.
1486 if (!extractSrcType.getElementType().isF32() &&
1487 !extractSrcType.getElementType().isInteger(32))
1488 return rewriter.notifyMatchFailure(
1489 extractOp, "only f32/i32 element types are supported");
1490 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1491 Type elType = extractSrcType.getElementType();
1492 VectorType distributedVecType;
1493 if (!is0dOrVec1Extract) {
1494 assert(extractSrcType.getRank() == 1 &&
1495 "expected that extract src rank is 0 or 1");
1496 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1497 return failure();
1498 int64_t elementsPerLane =
1499 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1500 distributedVecType = VectorType::get({elementsPerLane}, elType);
1501 } else {
1502 distributedVecType = extractSrcType;
1503 }
1504 // Yield source vector and position (if present) from warp op.
1505 SmallVector<Value> additionalResults{extractOp.getSource()};
1506 SmallVector<Type> additionalResultTypes{distributedVecType};
1507 additionalResults.append(
1508 SmallVector<Value>(extractOp.getDynamicPosition()));
1509 additionalResultTypes.append(
1510 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1511
1512 Location loc = extractOp.getLoc();
1513 SmallVector<size_t> newRetIndices;
1514 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1515 rewriter, warpOp, additionalResults, additionalResultTypes,
1516 newRetIndices);
1517 rewriter.setInsertionPointAfter(newWarpOp);
1518 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1519
1520 // 0d extract: The new warp op broadcasts the source vector to all lanes.
1521 // All lanes extract the scalar.
1522 if (is0dOrVec1Extract) {
1523 Value newExtract;
1524 SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1525 newExtract =
1526 vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
1527 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1528 newExtract);
1529 return success();
1530 }
1531
1532 int64_t staticPos = extractOp.getStaticPosition()[0];
1533 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1534 ? (newWarpOp->getResult(newRetIndices[1]))
1535 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1536 // 1d extract: Distribute the source vector. One lane extracts and shuffles
1537 // the value to all other lanes.
1538 int64_t elementsPerLane = distributedVecType.getShape()[0];
1539 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1540 // tid of extracting thread: pos / elementsPerLane
1541 Value broadcastFromTid = affine::makeComposedAffineApply(
1542 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1543 // Extract at position: pos % elementsPerLane
1544 Value newPos =
1545 elementsPerLane == 1
1546 ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
1547 : affine::makeComposedAffineApply(rewriter, loc,
1548 sym0 % elementsPerLane, pos);
1549 Value extracted =
1550 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1551
1552 // Shuffle the extracted value to all lanes.
1553 Value shuffled = warpShuffleFromIdxFn(
1554 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1555 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1556 return success();
1557 }
1558
1559private:
1560 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1561};
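To make the index arithmetic of the pattern above concrete: with a warp size of 32 and a `vector<64xf32>` source, the pattern assumes each lane owns `elementsPerLane = 2` contiguous elements. Extracting position 6 therefore becomes an extract of position `6 mod 2 = 0` from the distributed `vector<2xf32>` on lane `6 / 2 = 3`, followed by whatever broadcast the caller-supplied `warpShuffleFromIdxFn` emits (for example a `gpu.shuffle idx`). A sketch of the input side only (names invented for illustration):

```
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
  %v = "some_def"() : () -> (vector<64xf32>)
  %e = vector.extract %v[6] : f32 from vector<64xf32>
  gpu.yield %e : f32
}
```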
1562
1563/// Pattern to move out vector.insert with a scalar input.
1564/// Only supports 1-D and 0-D destinations for now.
1565struct WarpOpInsertScalar : public WarpDistributionPattern {
1566 using Base::Base;
1567 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1568 PatternRewriter &rewriter) const override {
1569 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1570 if (!operand)
1571 return failure();
1572 unsigned int operandNumber = operand->getOperandNumber();
1573 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1574 VectorType vecType = insertOp.getDestVectorType();
1575 VectorType distrType =
1576 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1577
1578 // Only supports 1-D or 0-D destinations for now.
1579 if (vecType.getRank() > 1) {
1580 return rewriter.notifyMatchFailure(
1581 insertOp, "only 0-D or 1-D destination supported for now");
1582 }
1583
1584 // Yield destination vector, source scalar and position from warp op.
1585 SmallVector<Value> additionalResults{insertOp.getDest(),
1586 insertOp.getValueToStore()};
1587 SmallVector<Type> additionalResultTypes{
1588 distrType, insertOp.getValueToStore().getType()};
1589 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1590 additionalResultTypes.append(
1591 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1592
1593 Location loc = insertOp.getLoc();
1594 SmallVector<size_t> newRetIndices;
1595 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1596 rewriter, warpOp, additionalResults, additionalResultTypes,
1597 newRetIndices);
1598 rewriter.setInsertionPointAfter(newWarpOp);
1599 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1600 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1601 rewriter.setInsertionPointAfter(newWarpOp);
1602
1603 OpFoldResult pos;
1604 if (vecType.getRank() != 0) {
1605 int64_t staticPos = insertOp.getStaticPosition()[0];
1606 pos = ShapedType::isDynamic(staticPos)
1607 ? (newWarpOp->getResult(newRetIndices[2]))
1608 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1609 }
1610
1611 // This condition is always true for 0-d vectors.
1612 if (vecType == distrType) {
1613 Value newInsert;
1614 SmallVector<OpFoldResult> indices;
1615 if (pos) {
1616 indices.push_back(pos);
1617 }
1618 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1619 distributedVec, indices);
1620 // Broadcast: Simply move the vector.insert op out.
1621 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1622 newInsert);
1623 return success();
1624 }
1625
1626 // This is a distribution. Only one lane should insert.
1627 int64_t elementsPerLane = distrType.getShape()[0];
1628 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1629 // tid of inserting lane: pos / elementsPerLane
1630 Value insertingLane = affine::makeComposedAffineApply(
1631 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1632 // Insert position: pos % elementsPerLane
1633 OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1634 rewriter, loc, sym0 % elementsPerLane, pos);
1635 Value isInsertingLane =
1636 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1637 newWarpOp.getLaneid(), insertingLane);
1638 Value newResult =
1639 scf::IfOp::create(
1640 rewriter, loc, isInsertingLane,
1641 /*thenBuilder=*/
1642 [&](OpBuilder &builder, Location loc) {
1643 Value newInsert = vector::InsertOp::create(
1644 builder, loc, newSource, distributedVec, newPos);
1645 scf::YieldOp::create(builder, loc, newInsert);
1646 },
1647 /*elseBuilder=*/
1648 [&](OpBuilder &builder, Location loc) {
1649 scf::YieldOp::create(builder, loc, distributedVec);
1650 })
1651 .getResult(0);
1652 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1653 return success();
1654 }
1655};
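In the distributed case handled above, only the lane that owns the element performs the insert; the other lanes forward their distributed destination unchanged. A rough sketch of the IR that ends up outside the warp op (warp size 32, a `vector<64xf32>` destination distributed as `vector<2xf32>`, static position 8; SSA names are invented and the affine computations are shown after folding):

```
// %dest and %val are the distributed destination and the scalar to insert,
// both yielded by the rewritten warp op. Only lane 8 / 2 = 4 writes.
%lane = arith.constant 4 : index
%is_lane = arith.cmpi eq, %laneid, %lane : index
%r = scf.if %is_lane -> (vector<2xf32>) {
  // Insert position within the lane: 8 mod 2 = 0.
  %ins = vector.insert %val, %dest [0] : f32 into vector<2xf32>
  scf.yield %ins : vector<2xf32>
} else {
  scf.yield %dest : vector<2xf32>
}
```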
1656
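/// Pattern to sink out vector.insert ops with a 2-D or higher dimensional
/// destination vector. 0-D and 1-D destinations are handled by the
/// WarpOpInsertScalar pattern instead.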
1657struct WarpOpInsert : public WarpDistributionPattern {
1658 using Base::Base;
1659 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1660 PatternRewriter &rewriter) const override {
1661 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1662 if (!operand)
1663 return failure();
1664 unsigned int operandNumber = operand->getOperandNumber();
1665 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1666 Location loc = insertOp.getLoc();
1667
1668 // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1669 if (insertOp.getDestVectorType().getRank() <= 1) {
1670 return failure();
1671 }
1672
1673 // All following cases are 2d or higher dimensional destination vectors.
1674
1675 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1676 // There is no distribution, this is a broadcast. Simply move the insert
1677 // out of the warp op.
1678 SmallVector<size_t> newRetIndices;
1679 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1680 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1681 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1682 newRetIndices);
1683 rewriter.setInsertionPointAfter(newWarpOp);
1684 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1685 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1686 Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1687 distributedDest,
1688 insertOp.getMixedPosition());
1689 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1690 newResult);
1691 return success();
1692 }
1693
1694 // Find the distributed dimension. There should be exactly one.
1695 auto distrDestType =
1696 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1697 auto yieldedType = cast<VectorType>(operand->get().getType());
1698 int64_t distrDestDim = -1;
1699 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1700 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1701 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1702 // support distributing multiple dimensions in the future.
1703 assert(distrDestDim == -1 && "found multiple distributed dims");
1704 distrDestDim = i;
1705 }
1706 }
1707 assert(distrDestDim != -1 && "could not find distributed dimension");
1708
1709 // Compute the distributed source vector type.
1710 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1711 SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1712 // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1713 // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1714 // insert a smaller vector<3xf32>.
1715 // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1716 // case, one lane will insert the source vector<96xf32>. The other
1717 // lanes will not do anything.
1718 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1719 if (distrSrcDim >= 0)
1720 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1721 auto distrSrcType =
1722 VectorType::get(distrSrcShape, distrDestType.getElementType());
1723
1724 // Yield source and dest vectors from warp op.
1725 SmallVector<size_t> newRetIndices;
1726 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1727 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1728 {distrSrcType, distrDestType}, newRetIndices);
1729 rewriter.setInsertionPointAfter(newWarpOp);
1730 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1731 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1732
1733 // Insert into the distributed vector.
1734 Value newResult;
1735 if (distrSrcDim >= 0) {
1736 // Every lane inserts a small piece.
1737 newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1738 distributedDest,
1739 insertOp.getMixedPosition());
1740 } else {
1741 // One lane inserts the entire source vector.
1742 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1743 SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1744 SmallVector<int64_t> newPos = getAsIntegers(pos);
1745 // tid of inserting lane: pos / elementsPerLane
1746 Value insertingLane = arith::ConstantIndexOp::create(
1747 rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1748 Value isInsertingLane =
1749 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1750 newWarpOp.getLaneid(), insertingLane);
1751 // Insert position: pos % elementsPerLane
1752 newPos[distrDestDim] %= elementsPerLane;
1753 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1754 Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1755 distributedDest, newPos);
1756 scf::YieldOp::create(builder, loc, newInsert);
1757 };
1758 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1759 scf::YieldOp::create(builder, loc, distributedDest);
1760 };
1761 newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1762 /*thenBuilder=*/insertingBuilder,
1763 /*elseBuilder=*/nonInsertingBuilder)
1764 .getResult(0);
1765 }
1766
1767 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1768 return success();
1769 }
1770};
1771
1772/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1773/// the scf.if is the last operation in the region so that it doesn't
1774/// change the order of execution. This creates a new scf.if after the
1775/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1776/// the "inner" WarpExecuteOnLane0Op. Example:
1777/// ```
1778/// gpu.warp_execute_on_lane_0(%laneid)[32] {
1779/// %payload = ... : vector<32xindex>
1780/// scf.if %pred {
1781/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1782/// vector<32xindex>
1783/// }
1784/// gpu.yield
1785/// }
1786/// ```
/// To
/// ```
1787/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1788/// %payload = ... : vector<32xindex>
1789/// gpu.yield %payload : vector<32xindex>
1790/// }
1791/// scf.if %pred {
1792/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1793/// ^bb0(%arg1: vector<32xindex>):
1794/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1795/// }
1796/// }
1797/// ```
1798struct WarpOpScfIfOp : public WarpDistributionPattern {
1799 WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1800 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1801 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1802 PatternRewriter &rewriter) const override {
1803 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1804 // Only pick up `IfOp` if it is the last op in the region.
1805 Operation *lastNode = warpOpYield->getPrevNode();
1806 auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1807 if (!ifOp)
1808 return failure();
1809
1810 // The current `WarpOp` can yield two types of values:
1811 // 1. Not results of `IfOp`:
1812 // Preserve them in the new `WarpOp`.
1813 // Collect their yield index to remap the usages.
1814 // 2. Results of `IfOp`:
1815 // They are not part of the new `WarpOp` results.
1816 // Map current warp's yield operand index to `IfOp` result idx.
1817 SmallVector<Value> nonIfYieldValues;
1818 SmallVector<unsigned> nonIfYieldIndices;
1819 llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1820 llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1821 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1822 const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
1823 if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
1824 nonIfYieldValues.push_back(yieldOperand.get());
1825 nonIfYieldIndices.push_back(yieldOperandIdx);
1826 continue;
1827 }
1828 OpResult ifResult = cast<OpResult>(yieldOperand.get());
1829 const unsigned ifResultIdx = ifResult.getResultNumber();
1830 ifResultMapping[yieldOperandIdx] = ifResultIdx;
1831 // If this `ifOp` result is a vector type and it is yielded by the
1832 // `WarpOp`, we keep track of the distributed type for this result.
1833 if (!isa<VectorType>(ifResult.getType()))
1834 continue;
1835 VectorType distType =
1836 cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1837 ifResultDistTypes[ifResultIdx] = distType;
1838 }
1839
1840 // Collect `WarpOp`-defined values used in `ifOp`; the new `WarpOp` must
1841 // return them.
1842 auto [escapingValuesThen, escapingValueInputTypesThen,
1843 escapingValueDistTypesThen] =
1844 getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1845 distributionMapFn);
1846 auto [escapingValuesElse, escapingValueInputTypesElse,
1847 escapingValueDistTypesElse] =
1848 getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1849 distributionMapFn);
1850 if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1851 llvm::is_contained(escapingValueDistTypesElse, Type{}))
1852 return failure();
1853
1854 // The new `WarpOp` yields values in the following order:
1855 // 1. The branch condition.
1856 // 2. Escaping values used in the then branch.
1857 // 3. Escaping values used in the else branch.
1858 // 4. All non-`ifOp` yielded values.
1859 SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1860 newWarpOpYieldValues.append(escapingValuesThen.begin(),
1861 escapingValuesThen.end());
1862 newWarpOpYieldValues.append(escapingValuesElse.begin(),
1863 escapingValuesElse.end());
1864 SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1865 newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1866 escapingValueDistTypesThen.end());
1867 newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1868 escapingValueDistTypesElse.end());
1869
1870 for (auto [idx, val] :
1871 llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1872 newWarpOpYieldValues.push_back(val);
1873 newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1874 }
1875 // Replace the old `WarpOp` with the new one that has additional yield
1876 // values and types.
1877 SmallVector<size_t> newIndices;
1878 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1879 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
1880 // `ifOp` returns the result of the inner warp op.
1881 SmallVector<Type> newIfOpDistResTypes;
1882 for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1883 Type distType = cast<Value>(res).getType();
1884 if (auto vecType = dyn_cast<VectorType>(distType)) {
1885 AffineMap map = distributionMapFn(cast<Value>(res));
1886 // Fall back to the affine map if the dist type was not previously recorded.
1887 distType = ifResultDistTypes.count(i)
1888 ? ifResultDistTypes[i]
1889 : getDistributedType(vecType, map, warpOp.getWarpSize());
1890 }
1891 newIfOpDistResTypes.push_back(distType);
1892 }
1893 // Create a new `IfOp` outside the new `WarpOp` region.
1894 OpBuilder::InsertionGuard g(rewriter);
1895 rewriter.setInsertionPointAfter(newWarpOp);
1896 auto newIfOp = scf::IfOp::create(
1897 rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1898 newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
1899 static_cast<bool>(ifOp.elseBlock()));
1900 auto encloseRegionInWarpOp =
1901 [&](Block *oldIfBranch, Block *newIfBranch,
1902 llvm::SmallSetVector<Value, 32> &escapingValues,
1903 SmallVector<Type> &escapingValueInputTypes,
1904 size_t warpResRangeStart) {
1905 OpBuilder::InsertionGuard g(rewriter);
1906 if (!newIfBranch)
1907 return;
1908 rewriter.setInsertionPointToStart(newIfBranch);
1909 llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1910 SmallVector<Value> innerWarpInputVals;
1911 SmallVector<Type> innerWarpInputTypes;
1912 for (size_t i = 0; i < escapingValues.size();
1913 ++i, ++warpResRangeStart) {
1914 innerWarpInputVals.push_back(
1915 newWarpOp.getResult(newIndices[warpResRangeStart]));
1916 escapeValToBlockArgIndex[escapingValues[i]] =
1917 innerWarpInputTypes.size();
1918 innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1919 }
1920 auto innerWarp = WarpExecuteOnLane0Op::create(
1921 rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1922 newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1923 innerWarpInputVals, innerWarpInputTypes);
1924
1925 innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1926 innerWarp.getWarpRegion().addArguments(
1927 innerWarpInputTypes,
1928 SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1929
1930 SmallVector<Value> yieldOperands;
1931 for (Value operand : oldIfBranch->getTerminator()->getOperands())
1932 yieldOperands.push_back(operand);
1933 rewriter.eraseOp(oldIfBranch->getTerminator());
1934
1935 rewriter.setInsertionPointToEnd(innerWarp.getBody());
1936 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1937 rewriter.setInsertionPointAfter(innerWarp);
1938 scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1939
1940 // Update any users of escaping values that were forwarded to the
1941 // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1942 innerWarp.walk([&](Operation *op) {
1943 for (OpOperand &operand : op->getOpOperands()) {
1944 auto it = escapeValToBlockArgIndex.find(operand.get());
1945 if (it == escapeValToBlockArgIndex.end())
1946 continue;
1947 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1948 }
1949 });
1950 mlir::vector::moveScalarUniformCode(innerWarp);
1951 };
1952 encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1953 &newIfOp.getThenRegion().front(), escapingValuesThen,
1954 escapingValueInputTypesThen, 1);
1955 if (!ifOp.getElseRegion().empty())
1956 encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1957 &newIfOp.getElseRegion().front(),
1958 escapingValuesElse, escapingValueInputTypesElse,
1959 1 + escapingValuesThen.size());
1960 // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
1961 // result.
1962 for (auto [origIdx, newIdx] : ifResultMapping)
1963 rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
1964 newIfOp.getResult(newIdx), newIfOp);
1965 return success();
1966 }
1967
1968private:
1969 DistributionMapFn distributionMapFn;
1970};
1971
1972/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1973/// the scf.ForOp is the last operation in the region so that it doesn't
1974/// change the order of execution. This creates a new scf.for region after the
1975/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1976/// WarpExecuteOnLane0Op region. Example:
1977/// ```
1978/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1979/// ...
1980/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1981/// -> (vector<128xf32>) {
1982/// ...
1983/// scf.yield %r : vector<128xf32>
1984/// }
1985/// gpu.yield %v1 : vector<128xf32>
1986/// }
1987/// ```
1988/// To:
1989/// %w0 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1990/// ...
1991/// gpu.yield %v : vector<128xf32>
1992/// }
1993/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1994/// -> (vector<4xf32>) {
1995/// %iw = gpu.warp_execute_on_lane_0(%laneid)
1996/// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1997/// ^bb0(%arg: vector<128xf32>):
1998/// ...
1999/// gpu.yield %ir : vector<128xf32>
2000/// }
2001/// scf.yield %iw : vector<4xf32>
2002/// }
2003/// ```
2004struct WarpOpScfForOp : public WarpDistributionPattern {
2005
2006 WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
2007 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
2008 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2009 PatternRewriter &rewriter) const override {
2010 gpu::YieldOp warpOpYield = warpOp.getTerminator();
2011 // Only pick up `ForOp` if it is the last op in the region.
2012 Operation *lastNode = warpOpYield->getPrevNode();
2013 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
2014 if (!forOp)
2015 return failure();
2016 // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
2017 // Those Values need to be returned by the new warp op.
2018 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2019 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2020 distributionMapFn);
2021 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2022 return failure();
2023 // `WarpOp` can yield two types of values:
2024 // 1. Values that are not results of the `ForOp`:
2025 // These values must also be yielded by the new `WarpOp`. Also, we need
2026 // to record the index mapping for these values to replace them later.
2027 // 2. Values that are results of the `ForOp`:
2028 // In this case, we record the index mapping between the `WarpOp` result
2029 // index and matching `ForOp` result index.
2030 // Additionally, we keep track of the distributed types for all `ForOp`
2031 // vector results.
2032 SmallVector<Value> nonForYieldedValues;
2033 SmallVector<unsigned> nonForResultIndices;
2034 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2035 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2036 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2037 // Yielded value is not a result of the forOp.
2038 if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
2039 nonForYieldedValues.push_back(yieldOperand.get());
2040 nonForResultIndices.push_back(yieldOperand.getOperandNumber());
2041 continue;
2042 }
2043 OpResult forResult = cast<OpResult>(yieldOperand.get());
2044 unsigned int forResultNumber = forResult.getResultNumber();
2045 forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
2046 // If this `ForOp` result is a vector type and it is yielded by the
2047 // `WarpOp`, we keep track of the distributed type for this result.
2048 if (!isa<VectorType>(forResult.getType()))
2049 continue;
2050 VectorType distType = cast<VectorType>(
2051 warpOp.getResult(yieldOperand.getOperandNumber()).getType());
2052 forResultDistTypes[forResultNumber] = distType;
2053 }
2054
2055 // The newly created `WarpOp` will yield values in the following order:
2056 // 1. Loop bounds.
2057 // 2. All init args of the `ForOp`.
2058 // 3. All escaping values.
2059 // 4. All non-`ForOp` yielded values.
2060 SmallVector<Value> newWarpOpYieldValues;
2061 SmallVector<Type> newWarpOpDistTypes;
2062 newWarpOpYieldValues.insert(
2063 newWarpOpYieldValues.end(),
2064 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2065 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2066 {forOp.getLowerBound().getType(),
2067 forOp.getUpperBound().getType(),
2068 forOp.getStep().getType()});
2069 for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2070 newWarpOpYieldValues.push_back(initArg);
2071 // Compute the distributed type for this init arg.
2072 Type distType = initArg.getType();
2073 if (auto vecType = dyn_cast<VectorType>(distType)) {
2074 // If the `ForOp` result corresponding to this init arg is already yielded,
2075 // we can get the distributed type from the `forResultDistTypes` map.
2076 // Otherwise, we compute it using distributionMapFn.
2077 AffineMap map = distributionMapFn(initArg);
2078 distType = forResultDistTypes.count(i)
2079 ? forResultDistTypes[i]
2080 : getDistributedType(vecType, map, warpOp.getWarpSize());
2081 }
2082 newWarpOpDistTypes.push_back(distType);
2083 }
2084 // Insert escaping values and their distributed types.
2085 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2086 escapingValues.begin(), escapingValues.end());
2087 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2088 escapingValueDistTypes.begin(),
2089 escapingValueDistTypes.end());
2090 // Next, we insert all non-`ForOp` yielded values and their distributed
2091 // types.
2092 for (auto [i, v] :
2093 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2094 newWarpOpYieldValues.push_back(v);
2095 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2096 }
2097 // Create the new `WarpOp` with the updated yield values and types.
2098 SmallVector<size_t> newIndices;
2099 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2100 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2101
2102 // Next, we create a new `ForOp` with the init args yielded by the new
2103 // `WarpOp`.
2104 const unsigned initArgsStartIdx = 3; // After loop bounds.
2105 const unsigned escapingValuesStartIdx =
2106 initArgsStartIdx +
2107 forOp.getInitArgs().size(); // `ForOp` init args are positioned before
2108 // escaping values in the new `WarpOp`.
2109 SmallVector<Value> newForOpOperands;
2110 for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2111 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2112
2113 // Create a new `ForOp` outside the new `WarpOp` region.
2114 OpBuilder::InsertionGuard g(rewriter);
2115 rewriter.setInsertionPointAfter(newWarpOp);
2116 auto newForOp = scf::ForOp::create(
2117 rewriter, forOp.getLoc(),
2118 /*lowerBound=*/newWarpOp.getResult(newIndices[0]),
2119 /*upperBound=*/newWarpOp.getResult(newIndices[1]),
2120 /*step=*/newWarpOp.getResult(newIndices[2]), newForOpOperands,
2121 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
2122 // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
2123 // newly created `ForOp`. This `WarpOp` will contain all ops that were
2124 // contained within the original `ForOp` body.
2125 rewriter.setInsertionPointToStart(newForOp.getBody());
2126
2127 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2128 newForOp.getRegionIterArgs().end());
2129 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2130 forOp.getResultTypes().end());
2131 // Escaping values are forwarded to the inner `WarpOp` as its (additional)
2132 // arguments. We keep track of the mapping between these values and their
2133 // argument index in the inner `WarpOp` (to replace users later).
2134 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2135 for (size_t i = escapingValuesStartIdx;
2136 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2137 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2138 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2139 innerWarpInputType.size();
2140 innerWarpInputType.push_back(
2141 escapingValueInputTypes[i - escapingValuesStartIdx]);
2142 }
2143 // Create the inner `WarpOp` with the new input values and types.
2144 auto innerWarp = WarpExecuteOnLane0Op::create(
2145 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2146 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2147 innerWarpInputType);
2148
2149 // Inline the `ForOp` body into the inner `WarpOp` body.
2150 SmallVector<Value> argMapping;
2151 argMapping.push_back(newForOp.getInductionVar());
2152 for (Value args : innerWarp.getBody()->getArguments())
2153 argMapping.push_back(args);
2154
2155 argMapping.resize(forOp.getBody()->getNumArguments());
2156 SmallVector<Value> yieldOperands;
2157 for (Value operand : forOp.getBody()->getTerminator()->getOperands())
2158 yieldOperands.push_back(operand);
2159
2160 rewriter.eraseOp(forOp.getBody()->getTerminator());
2161 rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2162
2163 // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
2164 // original `ForOp` results.
2165 rewriter.setInsertionPointToEnd(innerWarp.getBody());
2166 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2167 rewriter.setInsertionPointAfter(innerWarp);
2168 // Insert a scf.yield op at the end of the new `ForOp` body that yields
2169 // the inner `WarpOp` results.
2170 if (!innerWarp.getResults().empty())
2171 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2172
2173 // Update the users of the new `WarpOp` results that were coming from the
2174 // original `ForOp` to the corresponding new `ForOp` result.
2175 for (auto [origIdx, newIdx] : forResultMapping)
2176 rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
2177 newForOp.getResult(newIdx), newForOp);
2178 // Update any users of escaping values that were forwarded to the
2179 // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
2180 newForOp.walk([&](Operation *op) {
2181 for (OpOperand &operand : op->getOpOperands()) {
2182 auto it = argIndexMapping.find(operand.get());
2183 if (it == argIndexMapping.end())
2184 continue;
2185 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
2186 }
2187 });
2188
2189 // Finally, hoist out any now uniform code from the inner `WarpOp`.
2190 mlir::vector::moveScalarUniformCode(innerWarp);
2191 return success();
2192 }
2193
2194private:
2195 DistributionMapFn distributionMapFn;
2196};
2197
2198/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
2199/// The vector is reduced in parallel. Currently limited to vector sizes that
2200/// are a multiple of the warp size. E.g.:
2201/// ```
2202/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
2203/// %0 = "some_def"() : () -> (vector<32xf32>)
2204/// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
2205/// gpu.yield %1 : f32
2206/// }
2207/// ```
2208/// is lowered to:
2209/// ```
2210/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
2211/// %1 = "some_def"() : () -> (vector<32xf32>)
2212/// gpu.yield %1 : vector<32xf32>
2213/// }
2214/// %a = vector.extract %0[0] : f32 from vector<1xf32>
2215/// %r = ("warp.reduction %a")
2216/// ```
2217struct WarpOpReduction : public WarpDistributionPattern {
2218 WarpOpReduction(MLIRContext *context,
2219 DistributedReductionFn distributedReductionFn,
2220 PatternBenefit benefit = 1)
2221 : WarpDistributionPattern(context, benefit),
2222 distributedReductionFn(std::move(distributedReductionFn)) {}
2223
2224 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2225 PatternRewriter &rewriter) const override {
2226 OpOperand *yieldOperand =
2227 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2228 if (!yieldOperand)
2229 return failure();
2230
2231 auto reductionOp =
2232 cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
2233 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2234 // Only rank 1 vectors supported.
2235 if (vectorType.getRank() != 1)
2236 return rewriter.notifyMatchFailure(
2237 warpOp, "Only rank 1 reductions can be distributed.");
2238 // Only vector sizes that are a multiple of the warp size are supported.
2239 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2240 return rewriter.notifyMatchFailure(
2241 warpOp, "Reduction vector dimension must be a multiple of the warp size.");
2242 if (!reductionOp.getType().isIntOrFloat())
2243 return rewriter.notifyMatchFailure(
2244 warpOp, "Reduction distribution currently only supports floats and "
2245 "integer types.");
2246
2247 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2248 // Return vector that will be reduced from the WarpExecuteOnLane0Op.
2249 unsigned operandIndex = yieldOperand->getOperandNumber();
2250 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2251 SmallVector<Type> retTypes = {
2252 VectorType::get({numElements}, reductionOp.getType())};
2253 if (reductionOp.getAcc()) {
2254 yieldValues.push_back(reductionOp.getAcc());
2255 retTypes.push_back(reductionOp.getAcc().getType());
2256 }
2257 SmallVector<size_t> newRetIndices;
2258 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2259 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2260 rewriter.setInsertionPointAfter(newWarpOp);
2261
2262 // Obtain data to reduce for a single lane.
2263 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2264 // Distribute and reduce across threads.
2265 Value fullReduce =
2266 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2267 reductionOp.getKind(), newWarpOp.getWarpSize());
2268 if (reductionOp.getAcc()) {
2269 fullReduce = vector::makeArithReduction(
2270 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2271 newWarpOp.getResult(newRetIndices[1]));
2272 }
2273 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
2274 return success();
2275 }
2276
2277private:
2278 DistributedReductionFn distributedReductionFn;
2279};
2280
2281} // namespace
2282
2288
2289void mlir::vector::populateDistributeTransferWriteOpPatterns(
2290 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2291 unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2292 patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2293 maxNumElementsToExtract, benefit);
2294}
2295
2296void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2297 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2298 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2299 PatternBenefit readBenefit) {
2300 patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2301 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2302 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
2303 WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
2304 WarpOpCreateMask<vector::CreateMaskOp>,
2305 WarpOpCreateMask<vector::ConstantMaskOp>,
2306 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2307 patterns.getContext(), benefit);
2308 patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2309 benefit);
2310 patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2311 benefit);
2312 patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2313 benefit);
2314}
2315
2316 void mlir::vector::populateDistributeReduction(
2317 RewritePatternSet &patterns,
2318 const DistributedReductionFn &distributedReductionFn,
2319 PatternBenefit benefit) {
2320 patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2321 benefit);
2322}
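The populate entry points above only register the patterns; the caller supplies the distribution map and shuffle callbacks and then runs a rewrite driver. Below is a minimal wiring sketch, not part of this file: the callback signatures mirror how they are invoked by the patterns above, while the innermost-dimension layout choice, the `gpu.shuffle`-based broadcast, and the greedy driver call are assumptions modelled on typical usage.

```
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult distributeWarpOps(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);

  // Distribute every vector over its innermost dimension (an arbitrary choice
  // for this sketch; real users pick a layout per value).
  auto distributionMapFn = [](Value val) -> AffineMap {
    auto vecType = dyn_cast<VectorType>(val.getType());
    if (!vecType || vecType.getRank() == 0)
      return AffineMap::get(val.getContext());
    AffineExpr innermost =
        getAffineDimExpr(vecType.getRank() - 1, val.getContext());
    return AffineMap::get(vecType.getRank(), /*symbolCount=*/0, innermost);
  };

  // Broadcast a value from a given lane with gpu.shuffle idx.
  auto shuffleFn = [](Location loc, OpBuilder &b, Value val, Value srcIdx,
                      int64_t warpSz) -> Value {
    Type i32 = b.getIntegerType(32);
    Value srcIdxI32 = arith::IndexCastOp::create(b, loc, i32, srcIdx);
    Value warpSzI32 =
        arith::ConstantOp::create(b, loc, b.getIntegerAttr(i32, warpSz));
    return gpu::ShuffleOp::create(b, loc, val, srcIdxI32, warpSzI32,
                                  gpu::ShuffleMode::IDX)
        .getResult(0);
  };

  vector::populateDistributeTransferWriteOpPatterns(
      patterns, distributionMapFn, /*maxNumElementsToExtract=*/1,
      /*benefit=*/1);
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributionMapFn, shuffleFn, /*benefit=*/1,
      /*readBenefit=*/0);
  // populateDistributeReduction additionally needs a DistributedReductionFn
  // that lowers the per-lane reduction (e.g. to gpu.subgroup_reduce).

  return applyPatternsGreedily(root, std::move(patterns));
}
```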
2323
2324/// Helper to know if an op can be hoisted out of the region.
2325static bool canBeHoisted(Operation *op,
2326 function_ref<bool(Value)> definedOutside) {
2327 return llvm::all_of(op->getOperands(), definedOutside) &&
2328 isMemoryEffectFree(op) && op->getNumRegions() == 0;
2329}
2330
2331void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2332 Block *body = warpOp.getBody();
2333
2334 // Keep track of the ops we want to hoist.
2335 llvm::SmallSetVector<Operation *, 8> opsToMove;
2336
2337 // Helper to check if a value is or will be defined outside of the region.
2338 auto isDefinedOutsideOfBody = [&](Value value) {
2339 auto *definingOp = value.getDefiningOp();
2340 return (definingOp && opsToMove.count(definingOp)) ||
2341 warpOp.isDefinedOutsideOfRegion(value);
2342 };
2343
2344 // Do not use walk here, as we do not want to go into nested regions and hoist
2345 // operations from there.
2346 for (auto &op : body->without_terminator()) {
2347 bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2348 return isa<VectorType>(result.getType());
2349 });
2350 if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2351 opsToMove.insert(&op);
2352 }
2353
2354 // Move all the ops marked as uniform outside of the region.
2355 for (Operation *op : opsToMove)
2356 op->moveBefore(warpOp);
2357}
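For example (names invented for illustration), a scalar computation whose operands are defined above the region has no vector results and no memory effects, so it is hoisted, while the vector-producing op stays inside:

```
// Before moveScalarUniformCode:
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %off = arith.addi %a, %b : index      // uniform scalar, hoistable
  %v = vector.load %mem[%off] : memref<128xf32>, vector<32xf32>
  gpu.yield %v : vector<32xf32>
}

// After: the arith.addi is moved before the warp op; the vector.load stays.
%off = arith.addi %a, %b : index
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %v = vector.load %mem[%off] : memref<128xf32>, vector<32xf32>
  gpu.yield %v : vector<32xf32>
}
```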