VectorDistribute.cpp
1//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/IR/AffineOps.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/GPU/IR/GPUDialect.h"
12#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/Vector/IR/VectorOps.h"
16#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
17#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/Interfaces/SideEffectInterfaces.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <utility>
26
27using namespace mlir;
28using namespace mlir::vector;
29using namespace mlir::gpu;
30
31/// Currently the distribution map is implicit based on the vector shape. In the
32/// future it will be part of the op.
33/// Example:
34/// ```
35/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
36/// ...
37/// gpu.yield %3 : vector<32x16x64xf32>
38/// }
39/// ```
40/// Would have an implicit map of:
41/// `(d0, d1, d2) -> (d0, d2)`
42static AffineMap calculateImplicitMap(VectorType sequentialType,
43 VectorType distributedType) {
44 SmallVector<AffineExpr> perm;
45 perm.reserve(1);
46 // Check which dimensions of the sequential type are different than the
47 // dimensions of the distributed type to know the distributed dimensions. Then
48 // associate each distributed dimension to an ID in order.
49 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
51 perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
52 }
53 auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
54 distributedType.getContext());
55 return map;
56}
57
58/// Given a sequential and distributed vector type, returns the distributed
59/// dimension. This function expects that only a single dimension is
60/// distributed.
61static int getDistributedDim(VectorType sequentialType,
62 VectorType distributedType) {
63 assert(sequentialType.getRank() == distributedType.getRank() &&
64 "sequential and distributed vector types must have the same rank");
65 int64_t distributedDim = -1;
66 for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
68 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
69 // support distributing multiple dimensions in the future.
70 assert(distributedDim == -1 && "found multiple distributed dims");
71 distributedDim = i;
72 }
73 }
74 return distributedDim;
75}
76
77namespace {
78
79/// Helper struct to create the load / store operations that permit transit
80/// through the parallel / sequential and the sequential / parallel boundaries
81/// when performing `rewriteWarpOpToScfFor`.
82///
83/// The vector distribution dimension is inferred from the vector types.
84struct DistributedLoadStoreHelper {
85 DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
86 Value laneId, Value zero)
87 : sequentialVal(sequentialVal), distributedVal(distributedVal),
88 laneId(laneId), zero(zero) {
89 sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
90 distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
91 if (sequentialVectorType && distributedVectorType)
92 distributionMap =
93 calculateImplicitMap(sequentialVectorType, distributedVectorType);
94 }
95
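  /// Compute the buffer offset of the chunk owned by `laneId` along the
  /// distributed dimension `index`, i.e. `laneId * distributedSize`.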
96 Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
97 int64_t distributedSize = distributedVectorType.getDimSize(index);
98 AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
99 return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
100 ArrayRef<Value>{laneId});
101 }
102
103 /// Create a store during the process of distributing the
104 /// `gpu.warp_execute_on_lane_0` op.
105 /// Vector distribution assumes the following convention regarding the
106 /// temporary buffers that are created to transition values. This **must**
107 /// be properly specified in the `options.warpAllocationFn`:
108 /// 1. scalars of type T transit through a memref<1xT>.
109 /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
110 Operation *buildStore(RewriterBase &b, Location loc, Value val,
111 Value buffer) {
112 assert((val == distributedVal || val == sequentialVal) &&
113 "Must store either the preregistered distributed or the "
114 "preregistered sequential value.");
115 // Scalar case can directly use memref.store.
116 if (!isa<VectorType>(val.getType()))
117 return memref::StoreOp::create(b, loc, val, buffer, zero);
118
119 // Vector case must use vector::TransferWriteOp which will later lower to
120 // vector.store or memref.store depending on further lowerings.
121 int64_t rank = sequentialVectorType.getRank();
122 SmallVector<Value> indices(rank, zero);
123 if (val == distributedVal) {
124 for (auto dimExpr : distributionMap.getResults()) {
125 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126 indices[index] = buildDistributedOffset(b, loc, index);
127 }
128 }
129 SmallVector<bool> inBounds(indices.size(), true);
130 return vector::TransferWriteOp::create(
131 b, loc, val, buffer, indices,
132 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
133 }
134
135 /// Create a load during the process of distributing the
136 /// `gpu.warp_execute_on_lane_0` op.
137 /// Vector distribution assumes the following convention regarding the
138 /// temporary buffers that are created to transition values. This **must**
139 /// be properly specified in the `options.warpAllocationFn`:
140 /// 1. scalars of type T transit through a memref<1xT>.
141 /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
142 ///
143 /// When `type` is the sequential type, the load is not distributed, to
144 /// account for the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
145 ///
146 /// Example:
147 ///
148 /// ```
149 /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
150 /// gpu.yield %cst : f32
151 /// }
152 /// // Both types are f32. The constant %cst is broadcasted to all lanes.
153 /// ```
154 /// This behavior is described in more detail in the documentation of the op.
155 Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
156
157 // Scalar case can directly use memref.load.
158 if (!isa<VectorType>(type))
159 return memref::LoadOp::create(b, loc, buffer, zero);
160
161 // Other cases must be vector atm.
162 // Vector case must use vector::TransferReadOp which will later lower to
163 // vector.load or memref.load depending on further lowerings.
164 assert((type == distributedVectorType || type == sequentialVectorType) &&
165 "Must store either the preregistered distributed or the "
166 "preregistered sequential type.");
167 SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
168 if (type == distributedVectorType) {
169 for (auto dimExpr : distributionMap.getResults()) {
170 int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171 indices[index] = buildDistributedOffset(b, loc, index);
172 }
173 }
174 SmallVector<bool> inBounds(indices.size(), true);
175 return vector::TransferReadOp::create(
176 b, loc, cast<VectorType>(type), buffer, indices,
177 /*padding=*/std::nullopt,
178 ArrayRef<bool>(inBounds.begin(), inBounds.end()));
179 }
180
181 Value sequentialVal, distributedVal, laneId, zero;
182 VectorType sequentialVectorType, distributedVectorType;
183 AffineMap distributionMap;
184};
185
186} // namespace
187
188// Clones `op` into a new operation that takes `operands` and returns
189// `resultTypes`.
190static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
191 Location loc, Operation *op,
192 ArrayRef<Value> operands,
193 ArrayRef<Type> resultTypes) {
194 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
195 op->getAttrs());
196 return rewriter.create(res);
197}
198
199namespace {
200
201/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
202/// thread `laneId` executes the entirety of the computation.
203///
204/// After the transformation:
205/// - the IR within the scf.if op can be thought of as executing sequentially
206/// (from the point of view of threads along `laneId`).
207/// - the IR outside of the scf.if op can be thought of as executing in
208/// parallel (from the point of view of threads along `laneId`).
209///
210/// Values that need to transit through the parallel / sequential and the
211/// sequential / parallel boundaries do so via reads and writes to a temporary
212/// memory location.
213///
214/// The transformation proceeds in multiple steps:
215/// 1. Create the scf.if op.
216/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
217/// within the scf.if to transit the values captured from above.
218/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
219/// consistent within the scf.if.
220/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
221/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
222/// transit the values returned by the op.
223/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
224/// consistent after the scf.if.
225/// 7. Perform late cleanups.
226///
227/// All this assumes the vector distribution occurs along the most minor
228/// distributed vector dimension.
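///
/// A rough sketch of the produced IR for one captured value and one result
/// (buffers come from `options.warpAllocationFn`; names and offsets below are
/// illustrative only):
/// ```
/// vector.transfer_write %distributed_arg, %arg_buf[%laneid] ...
/// // warpSyncronizationFn
/// scf.if %laneid == %c0 {
///   %sequential_arg = vector.transfer_read %arg_buf[%c0] ...
///   ...                                  // original warp op body
///   vector.transfer_write %yielded_val, %res_buf[%c0] ...
/// }
/// // warpSyncronizationFn
/// %result = vector.transfer_read %res_buf[%laneid] ...
/// ```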
229struct WarpOpToScfIfPattern : public WarpDistributionPattern {
230 WarpOpToScfIfPattern(MLIRContext *context,
231 const WarpExecuteOnLane0LoweringOptions &options,
232 PatternBenefit benefit = 1)
233 : WarpDistributionPattern(context, benefit), options(options) {}
234
235 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236 PatternRewriter &rewriter) const override {
237 assert(warpOp.getBodyRegion().hasOneBlock() &&
238 "expected WarpOp with single block");
239 Block *warpOpBody = &warpOp.getBodyRegion().front();
240 Location loc = warpOp.getLoc();
241
242 // Passed all checks. Start rewriting.
243 OpBuilder::InsertionGuard g(rewriter);
244 rewriter.setInsertionPoint(warpOp);
245
246 // Step 1: Create scf.if op.
247 Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
248 Value isLane0 = arith::CmpIOp::create(
249 rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
250 auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
251 /*withElseRegion=*/false);
252 rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
253
254 // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
255 // reads within the scf.if to transit the values captured from above.
256 SmallVector<Value> bbArgReplacements;
257 for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
258 Value sequentialVal = warpOpBody->getArgument(it.index());
259 Value distributedVal = it.value();
260 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
261 warpOp.getLaneid(), c0);
262
263 // Create buffer before the ifOp.
264 rewriter.setInsertionPoint(ifOp);
265 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
266 sequentialVal.getType());
267 // Store distributed vector into buffer, before the ifOp.
268 helper.buildStore(rewriter, loc, distributedVal, buffer);
269 // Load sequential vector from buffer, inside the ifOp.
270 rewriter.setInsertionPointToStart(ifOp.thenBlock());
271 bbArgReplacements.push_back(
272 helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
273 }
274
275 // Step 3. Insert sync after all the stores and before all the loads.
276 if (!warpOp.getArgs().empty()) {
277 rewriter.setInsertionPoint(ifOp);
278 options.warpSyncronizationFn(loc, rewriter, warpOp);
279 }
280
281 // Step 4. Move body of warpOp to ifOp.
282 rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
283
284 // Step 5. Insert appropriate writes within scf.if and reads after the
285 // scf.if to transit the values returned by the op.
286 // TODO: at this point, we can reuse the shared memory from previous
287 // buffers.
288 SmallVector<Value> replacements;
289 auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
290 Location yieldLoc = yieldOp.getLoc();
291 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
292 Value sequentialVal = it.value();
293 Value distributedVal = warpOp->getResult(it.index());
294 DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
295 warpOp.getLaneid(), c0);
296
297 // Create buffer before the ifOp.
298 rewriter.setInsertionPoint(ifOp);
299 Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
300 sequentialVal.getType());
301
302 // Store yielded value into buffer, inside the ifOp, before the
303 // terminator.
304 rewriter.setInsertionPoint(yieldOp);
305 helper.buildStore(rewriter, loc, sequentialVal, buffer);
306
307 // Load distributed value from buffer, after the warpOp.
308 rewriter.setInsertionPointAfter(ifOp);
309 // Result type and yielded value type are the same. This is a broadcast.
310 // E.g.:
311 // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
312 // gpu.yield %cst : f32
313 // }
314 // Both types are f32. The constant %cst is broadcasted to all lanes.
315 // This is described in more detail in the documentation of the op.
316 replacements.push_back(
317 helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
318 }
319
320 // Step 6. Insert sync after all the stores and before all the loads.
321 if (!yieldOp.getOperands().empty()) {
322 rewriter.setInsertionPointAfter(ifOp);
323 options.warpSyncronizationFn(loc, rewriter, warpOp);
324 }
325
326 // Step 7. Delete terminator and add empty scf.yield.
327 rewriter.eraseOp(yieldOp);
328 rewriter.setInsertionPointToEnd(ifOp.thenBlock());
329 scf::YieldOp::create(rewriter, yieldLoc);
330
331 // Compute replacements for WarpOp results.
332 rewriter.replaceOp(warpOp, replacements);
333
334 return success();
335 }
336
337private:
338 const WarpExecuteOnLane0LoweringOptions &options;
339};
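
// A minimal usage sketch of this lowering (assumptions: the
// `populateWarpExecuteOnLane0OpToScfForPattern` entry point declared in
// VectorDistribution.h and the greedy pattern driver):
//
//   WarpExecuteOnLane0LoweringOptions options;
//   options.warpAllocationFn = ...;      // allocate the transit buffers
//   options.warpSyncronizationFn = ...;  // e.g. emit gpu.barrier
//   RewritePatternSet patterns(ctx);
//   vector::populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
//   (void)applyPatternsGreedily(op, std::move(patterns));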
340
341/// Return the distributed vector type based on the original type and the
342/// distribution map. The map is expected to have a number of dimensions equal
343/// to the original type rank and should be a projection where the results are
344/// the distributed dimensions. If the number of results is zero there is no
345/// distribution (i.e. the original type is returned).
346/// Otherwise, the number of results should be equal to the number
347/// of warp sizes, which is currently limited to 1.
348/// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
349/// and a warp size of 16 distributes the second dimension (associated with
350/// d1) and returns vector<16x2x64>.
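/// When a distributed dimension is smaller than the warp size, the remaining
/// lanes carry over to the next distributed dimension. E.g. (sketch, derived
/// from the loop below): vector<2x32xf32> with map (d0, d1) -> (d0, d1) and a
/// warp size of 64 yields vector<1x1xf32> (2 lanes on d0, 32 lanes on d1).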
351static VectorType getDistributedType(VectorType originalType, AffineMap map,
352 int64_t warpSize) {
353 // If the map has zero results, return the original type.
354 if (map.getNumResults() == 0)
355 return originalType;
356 SmallVector<int64_t> targetShape(originalType.getShape());
357 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
358 unsigned position = map.getDimPosition(i);
359 if (targetShape[position] % warpSize != 0) {
360 if (warpSize % targetShape[position] != 0) {
361 return VectorType();
362 }
363 warpSize /= targetShape[position];
364 targetShape[position] = 1;
365 continue;
366 }
367 targetShape[position] = targetShape[position] / warpSize;
368 warpSize = 1;
369 break;
370 }
371 if (warpSize != 1) {
372 return VectorType();
373 }
374 VectorType targetType =
375 VectorType::get(targetShape, originalType.getElementType());
376 return targetType;
377}
378
379/// Given a warpOp that contains ops with regions, the corresponding op's
380/// "inner" region and the distributionMapFn, get all values used by the op's
381/// region that are defined within the warpOp, but outside the inner region.
382/// Return the set of values, their types and their distributed types.
383std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
384 SmallVector<Type>>
385getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
386 DistributionMapFn distributionMapFn) {
387 llvm::SmallSetVector<Value, 32> escapingValues;
388 SmallVector<Type> escapingValueTypes;
389 SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
390 if (innerRegion.empty())
391 return {std::move(escapingValues), std::move(escapingValueTypes),
392 std::move(escapingValueDistTypes)};
393 mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
394 Operation *parent = operand->get().getParentRegion()->getParentOp();
395 if (warpOp->isAncestor(parent)) {
396 if (!escapingValues.insert(operand->get()))
397 return;
398 Type distType = operand->get().getType();
399 if (auto vecType = dyn_cast<VectorType>(distType)) {
400 AffineMap map = distributionMapFn(operand->get());
401 distType = getDistributedType(vecType, map,
402 map.isEmpty() ? 1 : warpOp.getWarpSize());
403 }
404 escapingValueTypes.push_back(operand->get().getType());
405 escapingValueDistTypes.push_back(distType);
406 }
407 });
408 return {std::move(escapingValues), std::move(escapingValueTypes),
409 std::move(escapingValueDistTypes)};
410}
411
412/// Distribute transfer_write ops based on the affine map returned by
413/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
414/// will not be distributed (it should be less than the warp size).
415///
416/// Example:
417/// ```
418/// %0 = gpu.warp_execute_on_lane_0(%id){
419/// ...
420/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
421/// gpu.yield
422/// }
423/// ```
424/// To
425/// ```
426/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
427/// ...
428/// gpu.yield %v : vector<32xf32>
429/// }
430/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
431struct WarpOpTransferWrite : public WarpDistributionPattern {
432 WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
433 unsigned maxNumElementsToExtract, PatternBenefit b = 1)
434 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
435 maxNumElementsToExtract(maxNumElementsToExtract) {}
436
437 /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
438 /// are multiples of the distribution ratio are supported at the moment.
439 LogicalResult tryDistributeOp(RewriterBase &rewriter,
440 vector::TransferWriteOp writeOp,
441 WarpExecuteOnLane0Op warpOp) const {
442 VectorType writtenVectorType = writeOp.getVectorType();
443
444 // 1. If the write is 0-D, it cannot be distributed here; bail out and let
445 // `tryExtractOp` move it into its own WarpExecuteOnLane0Op.
446 if (writtenVectorType.getRank() == 0)
447 return failure();
448
449 // 2. Compute the distributed type.
450 AffineMap map = distributionMapFn(writeOp.getVector());
451 VectorType targetType =
452 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
453 if (!targetType)
454 return failure();
455
456 // 2.5 Compute the distributed type for the new mask.
457 VectorType maskType;
458 if (writeOp.getMask()) {
459 // TODO: Distribution of masked writes with non-trivial permutation maps
460 // requires the distribution of the mask to elementwise match the
461 // distribution of the permuted written vector. Currently the details
462 // of which lane is responsible for which element is captured strictly
463 // by shape information on the warp op, and thus requires materializing
464 // the permutation in IR.
465 if (!writeOp.getPermutationMap().isMinorIdentity())
466 return failure();
467 maskType =
468 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
469 }
470
471 // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
472 // the rest.
473 vector::TransferWriteOp newWriteOp =
474 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
475
476 // 4. Reindex the write using the distribution map.
477 auto newWarpOp =
478 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
479
480 // Delinearize the lane id based on the way threads are divided across the
481 // vector. To get the number of threads per vector dimension, divide the
482 // sequential size by the distributed size along each dim.
483 rewriter.setInsertionPoint(newWriteOp);
484 SmallVector<OpFoldResult> delinearizedIdSizes;
485 for (auto [seqSize, distSize] :
486 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
487 assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
488 delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
489 }
490 SmallVector<Value> delinearized;
491 if (map.getNumResults() > 1) {
492 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
493 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
494 delinearizedIdSizes)
495 .getResults();
496 } else {
497 // If there is only one map result, we can elide the delinearization
498 // op and use the lane id directly.
499 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
500 }
501
502 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
503 Location loc = newWriteOp.getLoc();
504 SmallVector<Value> indices(newWriteOp.getIndices().begin(),
505 newWriteOp.getIndices().end());
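    // Each lane writes its slice at `origIndex + laneId_along_dim * distDimSize`;
    // the affine apply below folds this offset into the write indices.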
506 for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
507 AffineExpr d0, d1;
508 bindDims(newWarpOp.getContext(), d0, d1);
509 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
510 if (!indexExpr)
511 continue;
512 unsigned indexPos = indexExpr.getPosition();
513 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
514 Value laneId = delinearized[vectorPos];
515 auto scale =
516 rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
517 indices[indexPos] = affine::makeComposedAffineApply(
518 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
519 }
520 newWriteOp.getIndicesMutable().assign(indices);
521
522 return success();
523 }
524
525 /// Extract TransferWriteOps of vector<1x> into a separate warp op.
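  /// Sketch of the rewrite (the written value is yielded from the original
  /// warp op and the write itself is wrapped in a new lane-0-only warp op;
  /// names are illustrative only):
  /// ```
  /// %0:2 = gpu.warp_execute_on_lane_0(%id) -> (..., vector<1xf32>) {
  ///   ...
  ///   gpu.yield ..., %v : ..., vector<1xf32>
  /// }
  /// gpu.warp_execute_on_lane_0(%id) {
  ///   vector.transfer_write %0#1, %A[%c0] : vector<1xf32>, memref<128xf32>
  /// }
  /// ```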
526 LogicalResult tryExtractOp(RewriterBase &rewriter,
527 vector::TransferWriteOp writeOp,
528 WarpExecuteOnLane0Op warpOp) const {
529 Location loc = writeOp.getLoc();
530 VectorType vecType = writeOp.getVectorType();
531
532 if (vecType.getNumElements() > maxNumElementsToExtract) {
533 return rewriter.notifyMatchFailure(
534 warpOp,
535 llvm::formatv(
536 "writes more elements ({0}) than allowed to extract ({1})",
537 vecType.getNumElements(), maxNumElementsToExtract));
538 }
539
540 // Do not process warp ops that contain only TransferWriteOps.
541 if (llvm::all_of(warpOp.getOps(),
542 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
543 return failure();
544
545 SmallVector<Value> yieldValues = {writeOp.getVector()};
546 SmallVector<Type> retTypes = {vecType};
547 SmallVector<size_t> newRetIndices;
548 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
549 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
550 rewriter.setInsertionPointAfter(newWarpOp);
551
552 // Create a second warp op that contains only writeOp.
553 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
554 newWarpOp.getLaneid(),
555 newWarpOp.getWarpSize());
556 Block &body = secondWarpOp.getBodyRegion().front();
557 rewriter.setInsertionPointToStart(&body);
558 auto newWriteOp =
559 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
560 newWriteOp.getValueToStoreMutable().assign(
561 newWarpOp.getResult(newRetIndices[0]));
562 rewriter.eraseOp(writeOp);
563 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
564 return success();
565 }
566
567 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
568 PatternRewriter &rewriter) const override {
569 gpu::YieldOp yield = warpOp.getTerminator();
570 Operation *lastNode = yield->getPrevNode();
571 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
572 if (!writeOp)
573 return failure();
574
575 Value maybeMask = writeOp.getMask();
576 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
577 return writeOp.getVector() == value ||
578 (maybeMask && maybeMask == value) ||
579 warpOp.isDefinedOutsideOfRegion(value);
580 }))
581 return failure();
582
583 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
584 return success();
585
586 // Masked writes not supported for extraction.
587 if (writeOp.getMask())
588 return failure();
589
590 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
591 return success();
592
593 return failure();
594 }
595
596private:
597 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
598 /// execute op with the proper return type. The new write op is updated to
599 /// write the result of the new warp execute op. The old `writeOp` is deleted.
600 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
601 WarpExecuteOnLane0Op warpOp,
602 vector::TransferWriteOp writeOp,
603 VectorType targetType,
604 VectorType maybeMaskType) const {
605 assert(writeOp->getParentOp() == warpOp &&
606 "write must be nested immediately under warp");
607 OpBuilder::InsertionGuard g(rewriter);
608 SmallVector<size_t> newRetIndices;
609 WarpExecuteOnLane0Op newWarpOp;
610 if (maybeMaskType) {
611 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
612 rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
613 TypeRange{targetType, maybeMaskType}, newRetIndices);
614 } else {
615 newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
616 rewriter, warpOp, ValueRange{{writeOp.getVector()}},
617 TypeRange{targetType}, newRetIndices);
618 }
619 rewriter.setInsertionPointAfter(newWarpOp);
620 auto newWriteOp =
621 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
622 rewriter.eraseOp(writeOp);
623 newWriteOp.getValueToStoreMutable().assign(
624 newWarpOp.getResult(newRetIndices[0]));
625 if (maybeMaskType)
626 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
627 return newWriteOp;
628 }
629
630 DistributionMapFn distributionMapFn;
631 unsigned maxNumElementsToExtract = 1;
632};
633
634/// Sink out elementwise op feeding into a warp op yield.
635/// ```
636/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
637/// ...
638/// %3 = arith.addf %1, %2 : vector<32xf32>
639/// gpu.yield %3 : vector<32xf32>
640/// }
641/// ```
642/// To
643/// ```
644/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
645/// vector<1xf32>, vector<1xf32>) {
646/// ...
647/// %4 = arith.addf %2, %3 : vector<32xf32>
648/// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
649/// vector<32xf32>
650/// }
651/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
652struct WarpOpElementwise : public WarpDistributionPattern {
653 using Base::Base;
654 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
655 PatternRewriter &rewriter) const override {
656 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
657 return OpTrait::hasElementwiseMappableTraits(op);
658 });
659 if (!yieldOperand)
660 return failure();
661
662 Operation *elementWise = yieldOperand->get().getDefiningOp();
663 unsigned operandIndex = yieldOperand->getOperandNumber();
664 Value distributedVal = warpOp.getResult(operandIndex);
665 SmallVector<Value> yieldValues;
666 SmallVector<Type> retTypes;
667 Location loc = warpOp.getLoc();
668 for (OpOperand &operand : elementWise->getOpOperands()) {
669 Type targetType;
670 if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
671 // If the result type is a vector, the operands must also be vectors.
672 auto operandType = cast<VectorType>(operand.get().getType());
673 targetType =
674 VectorType::get(vecType.getShape(), operandType.getElementType());
675 } else {
676 auto operandType = operand.get().getType();
677 assert(!isa<VectorType>(operandType) &&
678 "unexpected yield of vector from op with scalar result type");
679 targetType = operandType;
680 }
681 retTypes.push_back(targetType);
682 yieldValues.push_back(operand.get());
683 }
684 SmallVector<size_t> newRetIndices;
685 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
686 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
687 rewriter.setInsertionPointAfter(newWarpOp);
688 SmallVector<Value> newOperands(elementWise->getOperands().begin(),
689 elementWise->getOperands().end());
690 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
691 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
692 }
693 OpBuilder::InsertionGuard g(rewriter);
694 rewriter.setInsertionPointAfter(newWarpOp);
695 Operation *newOp = cloneOpWithOperandsAndTypes(
696 rewriter, loc, elementWise, newOperands,
697 {newWarpOp.getResult(operandIndex).getType()});
698 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
699 newOp->getResult(0));
700 return success();
701 }
702};
703
704/// Sink out splat constant op feeding into a warp op yield.
705/// ```
706/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
707/// ...
708/// %cst = arith.constant dense<2.0> : vector<32xf32>
709/// gpu.yield %cst : vector<32xf32>
710/// }
711/// ```
712/// To
713/// ```
714/// gpu.warp_execute_on_lane_0(%arg0) {
715/// ...
716/// }
717/// %0 = arith.constant dense<2.0> : vector<1xf32>
718struct WarpOpConstant : public WarpDistributionPattern {
719 using Base::Base;
720 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
721 PatternRewriter &rewriter) const override {
722 OpOperand *yieldOperand =
723 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
724 if (!yieldOperand)
725 return failure();
726 auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
727 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
728 if (!dense)
729 return failure();
730 // Notify the rewriter that the warp op is changing (see the comment on
731 // the WarpOpTransferRead pattern).
732 rewriter.startOpModification(warpOp);
733 unsigned operandIndex = yieldOperand->getOperandNumber();
734 Attribute scalarAttr = dense.getSplatValue<Attribute>();
735 auto newAttr = DenseElementsAttr::get(
736 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
737 Location loc = warpOp.getLoc();
738 rewriter.setInsertionPointAfter(warpOp);
739 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
740 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
741 rewriter.finalizeOpModification(warpOp);
742 return success();
743 }
744};
745
746/// Sink out step op feeding into a warp op yield.
747/// Vector step op is treated similarly to arith.constant, apart from
748/// the result, which represents the sequence [0, vec_size).
749/// Due to the vec_size == warp_size limitation,
750/// we can simply wrap the lane id into a vector (i.e., broadcast).
751/// Supporting vec_size != warp_size may involve preserving the step
752/// result and using additional arith ops (the exact details are TBD).
753/// ```
754/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
755/// ...
756/// %cst = vector.step : vector<32xindex>
757/// gpu.yield %cst : vector<32xindex>
758/// }
759/// ```
760/// To
761/// ```
762/// gpu.warp_execute_on_lane_0(%arg0) {
763/// ...
764/// }
765/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
766struct WarpOpStep final : public WarpDistributionPattern {
767 using Base::Base;
768 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
769 PatternRewriter &rewriter) const override {
770 OpOperand *yieldOperand =
771 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
772 if (!yieldOperand)
773 return failure();
774 const unsigned operandIdx = yieldOperand->getOperandNumber();
775 auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
776 VectorType resTy = stepOp.getResult().getType();
777 if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
778 return rewriter.notifyMatchFailure(
779 warpOp,
780 llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
781 resTy.getNumElements(), warpOp.getWarpSize()));
782 VectorType newVecTy =
783 cast<VectorType>(warpOp.getResult(operandIdx).getType());
784 rewriter.setInsertionPointAfter(warpOp);
785 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
786 newVecTy, warpOp.getLaneid());
787 rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
788 return success();
789 }
790};
791
792/// Sink out transfer_read op feeding into a warp op yield.
793/// ```
794/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
795/// ...
796/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
797/// vector<32xf32>
798/// gpu.yield %2 : vector<32xf32>
799/// }
800/// ```
801/// To
802/// ```
803/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
804/// vector<1xf32>, vector<1xf32>) {
805/// ...
806/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
807/// vector<32xf32>
/// gpu.yield %2 : vector<32xf32>
808/// }
809/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
810struct WarpOpTransferRead : public WarpDistributionPattern {
811 using Base::Base;
812 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
813 PatternRewriter &rewriter) const override {
814 // Try to find a distributable yielded read. Note that this pattern can
815 // still fail at the end after distribution, in which case this might have
816 // missed another distributable read.
817 OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
818 // Don't duplicate transfer_read ops when distributing.
819 return isa<vector::TransferReadOp>(op) && op->hasOneUse();
820 });
821 if (!operand)
822 return rewriter.notifyMatchFailure(
823 warpOp, "warp result is not a vector.transfer_read op");
824 auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
825
826 // Source must be defined outside of the region.
827 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
828 return rewriter.notifyMatchFailure(
829 read, "source must be defined outside of the region");
830
831 unsigned operandIndex = operand->getOperandNumber();
832 Value distributedVal = warpOp.getResult(operandIndex);
833
834 SmallVector<Value, 4> indices(read.getIndices().begin(),
835 read.getIndices().end());
836 auto sequentialType = cast<VectorType>(read.getResult().getType());
837 auto distributedType = cast<VectorType>(distributedVal.getType());
838 AffineMap map = calculateImplicitMap(sequentialType, distributedType);
839 AffineMap indexMap = map.compose(read.getPermutationMap());
840
841 // Try to delinearize the lane ID to match the rank expected for
842 // distribution.
843 SmallVector<Value> delinearizedIds;
844 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
845 distributedType.getShape(), warpOp.getWarpSize(),
846 warpOp.getLaneid(), delinearizedIds)) {
847 return rewriter.notifyMatchFailure(
848 read, "cannot delinearize lane ID for distribution");
849 }
850 assert(!delinearizedIds.empty() || map.getNumResults() == 0);
851
852 // Distribute indices and the mask (if present).
853 OpBuilder::InsertionGuard g(rewriter);
854 SmallVector<Value> additionalResults(indices.begin(), indices.end());
855 SmallVector<Type> additionalResultTypes(indices.size(),
856 rewriter.getIndexType());
857 additionalResults.push_back(read.getPadding());
858 additionalResultTypes.push_back(read.getPadding().getType());
859
860 bool hasMask = false;
861 if (read.getMask()) {
862 hasMask = true;
863 // TODO: Distribution of masked reads with non-trivial permutation maps
864 // requires the distribution of the mask to elementwise match the
865 // distribution of the permuted read vector. Currently the details
866 // of which lane is responsible for which element is captured strictly
867 // by shape information on the warp op, and thus requires materializing
868 // the permutation in IR.
869 if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
870 return rewriter.notifyMatchFailure(
871 read, "non-trivial permutation maps not supported");
872 VectorType maskType =
873 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
874 additionalResults.push_back(read.getMask());
875 additionalResultTypes.push_back(maskType);
876 }
877
878 SmallVector<size_t> newRetIndices;
879 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
880 rewriter, warpOp, additionalResults, additionalResultTypes,
881 newRetIndices);
882 distributedVal = newWarpOp.getResult(operandIndex);
883
884 // Distributed indices were appended first.
885 SmallVector<Value> newIndices;
886 for (int64_t i = 0, e = indices.size(); i < e; ++i)
887 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
888
889 rewriter.setInsertionPointAfter(newWarpOp);
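    // Shift each indexed dimension by the lane offset so that each lane reads
    // at `origIndex + laneId_along_dim * distDimSize`.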
890 for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
891 AffineExpr d0, d1;
892 bindDims(read.getContext(), d0, d1);
893 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
894 if (!indexExpr)
895 continue;
896 unsigned indexPos = indexExpr.getPosition();
897 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
898 int64_t scale = distributedType.getDimSize(vectorPos);
899 newIndices[indexPos] = affine::makeComposedAffineApply(
900 rewriter, read.getLoc(), d0 + scale * d1,
901 {newIndices[indexPos], delinearizedIds[vectorPos]});
902 }
903
904 // Distributed padding value was appended right after the indices.
905 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
906 // Distributed mask value was added at the end (if the op has a mask).
907 Value newMask =
908 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
909 : Value();
910 auto newRead = vector::TransferReadOp::create(
911 rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
912 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
913 read.getInBoundsAttr());
914
915 rewriter.replaceAllUsesWith(distributedVal, newRead);
916 return success();
917 }
918};
919
920/// Remove any result that has no use along with the matching yieldOp operand.
921// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
922struct WarpOpDeadResult : public WarpDistributionPattern {
923 using Base::Base;
924 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
925 PatternRewriter &rewriter) const override {
926 SmallVector<Type> newResultTypes;
927 newResultTypes.reserve(warpOp->getNumResults());
928 SmallVector<Value> newYieldValues;
929 newYieldValues.reserve(warpOp->getNumResults());
930 DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
931 DenseMap<OpResult, int64_t> dedupResultPositionMap;
932 gpu::YieldOp yield = warpOp.getTerminator();
933
934 // Some values may be yielded multiple times and correspond to multiple
935 // results. Deduplicating occurs by taking each result with its matching
936 // yielded value, and:
937 // 1. recording the unique first position at which the value with uses is
938 // yielded.
939 // 2. recording for the result, the first position at which the dedup'ed
940 // value is yielded.
941 // 3. skipping from the new result types / new yielded values any result
942 // that has no use or whose yielded value has already been seen.
943 for (OpResult result : warpOp.getResults()) {
944 if (result.use_empty())
945 continue;
946 Value yieldOperand = yield.getOperand(result.getResultNumber());
947 auto it = dedupYieldOperandPositionMap.insert(
948 std::make_pair(yieldOperand, newResultTypes.size()));
949 dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
950 if (!it.second)
951 continue;
952 newResultTypes.push_back(result.getType());
953 newYieldValues.push_back(yieldOperand);
954 }
955 // No modification, exit early.
956 if (yield.getNumOperands() == newYieldValues.size())
957 return failure();
958 // Move the body of the old warpOp to a new warpOp.
959 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
960 rewriter, warpOp, newYieldValues, newResultTypes);
961
962 // Simplify the new warp op after dropping dead results.
963 newWarpOp.getBody()->walk([&](Operation *op) {
964 if (isOpTriviallyDead(op))
965 rewriter.eraseOp(op);
966 });
967
968 // Replace results of the old warpOp by the new, deduplicated results.
969 SmallVector<Value> newValues;
970 newValues.reserve(warpOp->getNumResults());
971 for (OpResult result : warpOp.getResults()) {
972 if (result.use_empty())
973 newValues.push_back(Value());
974 else
975 newValues.push_back(
976 newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
977 }
978 rewriter.replaceOp(warpOp, newValues);
979 return success();
980 }
981};
982
983// If an operand is yielded directly out of the region, we can forward it
984// as-is; it does not need to go through the region.
985struct WarpOpForwardOperand : public WarpDistributionPattern {
986 using Base::Base;
987 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
988 PatternRewriter &rewriter) const override {
989 gpu::YieldOp yield = warpOp.getTerminator();
990 Value valForwarded;
991 unsigned resultIndex;
992 for (OpOperand &operand : yield->getOpOperands()) {
993 Value result = warpOp.getResult(operand.getOperandNumber());
994 if (result.use_empty())
995 continue;
996
997 // Assume all the values coming from above are uniform.
998 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
999 if (result.getType() != operand.get().getType())
1000 continue;
1001 valForwarded = operand.get();
1002 resultIndex = operand.getOperandNumber();
1003 break;
1004 }
1005 auto arg = dyn_cast<BlockArgument>(operand.get());
1006 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1007 continue;
1008 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1009 if (result.getType() != warpOperand.getType())
1010 continue;
1011 valForwarded = warpOperand;
1012 resultIndex = operand.getOperandNumber();
1013 break;
1014 }
1015 if (!valForwarded)
1016 return failure();
1017 // Notify the rewriter that the warp op is changing (see the comment on
1018 // the WarpOpTransferRead pattern).
1019 rewriter.startOpModification(warpOp);
1020 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1021 rewriter.finalizeOpModification(warpOp);
1022 return success();
1023 }
1024};
1025
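/// Sink out a broadcast op feeding into a warp op yield. The broadcast source
/// is yielded from the warp op (all lanes see the same value, per the op's
/// broadcast semantics) and the broadcast is re-created after the warp op with
/// the distributed result type.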
1026struct WarpOpBroadcast : public WarpDistributionPattern {
1027 using Base::Base;
1028 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1029 PatternRewriter &rewriter) const override {
1030 OpOperand *operand =
1031 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1032 if (!operand)
1033 return failure();
1034 unsigned int operandNumber = operand->getOperandNumber();
1035 auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1036 Location loc = broadcastOp.getLoc();
1037 auto destVecType =
1038 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1039 Value broadcastSrc = broadcastOp.getSource();
1040 Type broadcastSrcType = broadcastSrc.getType();
1041
1042 // Check that the broadcast actually spans a set of values uniformly across
1043 // all threads. In other words, check that each thread can reconstruct
1044 // their own broadcast.
1045 // For that we simply check that the broadcast we want to build makes sense.
1046 if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1047 vector::BroadcastableToResult::Success)
1048 return failure();
1049 SmallVector<size_t> newRetIndices;
1050 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1051 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1052 rewriter.setInsertionPointAfter(newWarpOp);
1053 Value broadcasted = vector::BroadcastOp::create(
1054 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1055 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1056 broadcasted);
1057 return success();
1058 }
1059};
1060
1061/// Pattern to move shape cast out of the warp op. shape cast is basically a
1062/// no-op for warp distribution; we need to handle the shape though.
1063struct WarpOpShapeCast : public WarpDistributionPattern {
1064 using Base::Base;
1065 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1066 PatternRewriter &rewriter) const override {
1067 OpOperand *operand =
1068 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1069 if (!operand)
1070 return failure();
1071
1072 auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1073
1074 unsigned int operandNumber = operand->getOperandNumber();
1075 auto castDistributedType =
1076 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1077 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1078 VectorType castResultType = castDistributedType;
1079
1080 // We expect the distributed type to have a smaller rank than the original
1081 // type. Prepend with size-one dimensions to make them the same.
1082 unsigned castDistributedRank = castDistributedType.getRank();
1083 unsigned castOriginalRank = castOriginalType.getRank();
1084 if (castDistributedRank < castOriginalRank) {
1085 SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1086 llvm::append_range(shape, castDistributedType.getShape());
1087 castDistributedType =
1088 VectorType::get(shape, castDistributedType.getElementType());
1089 }
1090
1091 SmallVector<size_t> newRetIndices;
1092 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1093 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1094 newRetIndices);
1095 rewriter.setInsertionPointAfter(newWarpOp);
1096 Value newCast = vector::ShapeCastOp::create(
1097 rewriter, oldCastOp.getLoc(), castResultType,
1098 newWarpOp->getResult(newRetIndices[0]));
1099 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1100 return success();
1101 }
1102};
1103
1104/// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op
1105/// yield.
1106/// ```
1107/// %0 = ...
1108/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1109/// ...
1110/// %mask = vector.create_mask %0 : vector<32xi1>
1111/// // or %mask = vector.constant_mask[2] : vector<32xi1>
1112/// gpu.yield %mask : vector<32xi1>
1113/// }
1114/// ```
1115/// To
1116/// ```
1117/// %0 = ...
1118/// gpu.warp_execute_on_lane_0(%arg0) {
1119/// ...
1120/// }
1121/// %cmp = arith.cmpi ult, %laneid, %0
1122/// %ub = arith.select %cmp, %c0, %c1
1123/// %1 = vector.create_mask %ub : vector<1xi1>
1124template <typename OpType,
1125 typename = std::enable_if_t<llvm::is_one_of<
1126 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
1127struct WarpOpCreateMask : public WarpDistributionPattern {
1128 using Base::Base;
1129 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1130 PatternRewriter &rewriter) const override {
1131 OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
1132 if (!yieldOperand)
1133 return failure();
1134
1135 Operation *mask = yieldOperand->get().getDefiningOp<OpType>();
1136
1137 // Early exit if any values needed for calculating the new mask indices
1138 // are defined inside the warp op.
1139 if (mask->getOperands().size() &&
1140 !llvm::all_of(mask->getOperands(), [&](Value value) {
1141 return warpOp.isDefinedOutsideOfRegion(value);
1142 }))
1143 return failure();
1144
1145 Location loc = mask->getLoc();
1146 unsigned operandIndex = yieldOperand->getOperandNumber();
1147
1148 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1149 VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
1150 ArrayRef<int64_t> seqShape = seqType.getShape();
1151 ArrayRef<int64_t> distShape = distType.getShape();
1152 SmallVector<Value> materializedOperands;
1153 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
1154 materializedOperands.append(mask->getOperands().begin(),
1155 mask->getOperands().end());
1156 } else {
1157 auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
1158 auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
1159 for (auto dimSize : dimSizes)
1160 materializedOperands.push_back(
1161 arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
1162 }
1163
1164 rewriter.setInsertionPointAfter(warpOp);
1165
1166 // Delinearize the lane ID for constructing the distributed mask sizes.
1167 SmallVector<Value> delinearizedIds;
1168 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1169 warpOp.getWarpSize(), warpOp.getLaneid(),
1170 delinearizedIds))
1171 return rewriter.notifyMatchFailure(
1172 mask, "cannot delinearize lane ID for distribution");
1173 assert(!delinearizedIds.empty());
1174
1175 // Notify the rewriter that the warp op is changing (see the comment on
1176 // the WarpOpTransferRead pattern).
1177 rewriter.startOpModification(warpOp);
1178
1179 AffineExpr s0, s1;
1180 bindSymbols(rewriter.getContext(), s0, s1);
1181 SmallVector<Value> newOperands;
1182 for (int i = 0, e = distShape.size(); i < e; ++i) {
1183 // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1184 // find the distance from the largest mask index owned by this lane to the
1185 // original mask size. `vector.create_mask` implicitly clamps mask
1186 // operands to the range [0, mask_vector_size[i]], or in other words, the
1187 // mask sizes are always in the range [0, mask_vector_size[i]].
1188 Value maskDimIdx = affine::makeComposedAffineApply(
1189 rewriter, loc, s1 - s0 * distShape[i],
1190 {delinearizedIds[i], materializedOperands[i]});
1191 newOperands.push_back(maskDimIdx);
1192 }
1193
1194 auto newMask =
1195 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
1196 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1197 rewriter.finalizeOpModification(warpOp);
1198 return success();
1199 }
1200};
1201
1202/// Sink out insert_strided_slice op feeding into a warp op yield.
1203/// ```
1204/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1205/// ...
1206/// %src = ... : vector<4x32xf32>
1207/// %dest = ... : vector<8x32xf32>
1208/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1209/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1210/// gpu.yield %insert : vector<8x32xf32>
1211/// }
1212/// ```
1213/// To
1214/// ```
1215/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1216/// vector<8x1xf32>) {
1217/// ...
1218/// %src = ... : vector<4x32xf32>
1219/// %dest = ... : vector<8x32xf32>
1220/// gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
1221/// }
1222/// %insert = vector.insert_strided_slice %0#0, %0#1,
1223/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1224/// ```
1225/// NOTE: Current support assumes that both src and dest vectors are distributed
1226/// to lanes and sinking the insert op does not require any cross lane
1227/// communication.
1228struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
1229 using Base::Base;
1230 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1231 PatternRewriter &rewriter) const override {
1232 OpOperand *operand =
1233 getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1234 if (!operand)
1235 return failure();
1236 unsigned int operandNumber = operand->getOperandNumber();
1237 auto insertOp =
1238 operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1239 auto distributedType =
1240 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1241 // Distributed type must be 2D or higher.
1242 // TODO: Support 1D distributed types.
1243 if (distributedType.getRank() < 2)
1244 return rewriter.notifyMatchFailure(
1245 insertOp, "result vector type must be 2D or higher");
1246 // Find the distributed dimension of the dest vector. There should be
1247 // exactly one.
1248 auto yieldedType = cast<VectorType>(operand->get().getType());
1249 int64_t destDistributedDim =
1250 getDistributedDim(yieldedType, distributedType);
1251 assert(destDistributedDim != -1 && "could not find distributed dimension");
1252
1253 VectorType srcType = insertOp.getSourceVectorType();
1254 VectorType destType = insertOp.getDestVectorType();
1255 // Currently we require that both source (kD) and dest (nD) vectors are
1256 // distributed. This requires that distributedDim (d) is contained in the
1257 // last k dims of the dest vector (d >= n - k).
1258 // TODO: Add support for case where source vector is not distributed.
1259 int64_t sourceDistributedDim =
1260 destDistributedDim - (destType.getRank() - srcType.getRank());
1261 if (sourceDistributedDim < 0)
1262 return rewriter.notifyMatchFailure(
1263 insertOp,
1264 "distributed dimension must be in the last k dims of dest vector");
1265 // Distributed dimension must be fully inserted.
1266 if (srcType.getDimSize(sourceDistributedDim) !=
1267 destType.getDimSize(destDistributedDim))
1268 return rewriter.notifyMatchFailure(
1269 insertOp, "distributed dimension must be fully inserted");
1270 SmallVector<int64_t> newSourceDistShape(
1271 insertOp.getSourceVectorType().getShape());
1272 newSourceDistShape[sourceDistributedDim] =
1273 distributedType.getDimSize(destDistributedDim);
1274 auto newSourceTy =
1275 VectorType::get(newSourceDistShape, distributedType.getElementType());
1276 VectorType newDestTy = distributedType;
1277 SmallVector<size_t> newRetIndices;
1278 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1279 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1280 {newSourceTy, newDestTy}, newRetIndices);
1281 rewriter.setInsertionPointAfter(newWarpOp);
1282 Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1283 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1284 // Create a new insert strided slice op that inserts distributed source into
1285 // distributed dest.
1286 Value newInsert = vector::InsertStridedSliceOp::create(
1287 rewriter, insertOp.getLoc(), distributedDest.getType(),
1288 distributedSource, distributedDest, insertOp.getOffsets(),
1289 insertOp.getStrides());
1290 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
1291 return success();
1292 }
1293};
1294
1295/// Sink out extract_strided_slice op feeding into a warp op yield.
1296/// ```
1297/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1298/// ...
1299/// %src = ... : vector<64x32xf32>
1300/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1301/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1302/// gpu.yield %extract : vector<16x32xf32>
1303/// }
1304/// ```
1305/// To
1306/// ```
1307/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1308/// ...
1309/// %src = ... : vector<64x32xf32>
1310/// gpu.yield %src : vector<64x32xf32>
1311/// }
1312/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1313/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1314/// ```
1315/// NOTE: Current support assumes that the extraction happens only on non
1316/// distributed dimensions (does not require cross lane communication).
1317struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
1318 using Base::Base;
1319 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1320 PatternRewriter &rewriter) const override {
1321 OpOperand *operand =
1322 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1323 if (!operand)
1324 return failure();
1325 unsigned int operandNumber = operand->getOperandNumber();
1326 auto extractOp =
1327 operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
1328 auto distributedType =
1329 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1330 // Distributed type must be 2D or higher.
1331 // TODO: Support 1D distributed types.
1332 if (distributedType.getRank() < 2)
1333 return rewriter.notifyMatchFailure(
1334 extractOp, "result vector type must be 2D or higher");
1335
1336 // Find the distributed dimension. There should be exactly one.
1337 auto yieldedType = cast<VectorType>(operand->get().getType());
1338 int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1339 assert(distributedDim != -1 && "could not find distributed dimension");
1340
1341 int64_t numOfExtractedDims =
1342 static_cast<int64_t>(extractOp.getSizes().size());
1343 // If the distributed dim is included in the extracted dims, then we make
1344 // sure distributed dim is fully extracted. If distributed dim is not
1345 // included in extracted dims, it is guaranteed to be fully extracted (i.e.
1346 // distributed dim comes after all the extracted dims)
1347 // TODO: Partial extraction from distributed dimension require cross lane
1348 // communication.
1349 if (distributedDim < numOfExtractedDims) {
1350 int64_t distributedDimOffset =
1351 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1352 .getInt();
1353 int64_t distributedDimSize =
1354 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1355 .getInt();
1356 if (distributedDimOffset != 0 ||
1357 distributedDimSize != yieldedType.getDimSize(distributedDim))
1358 return rewriter.notifyMatchFailure(
1359 extractOp, "distributed dimension must be fully extracted");
1360 }
1361 SmallVector<int64_t> newDistributedShape(
1362 extractOp.getSourceVectorType().getShape());
1363 newDistributedShape[distributedDim] =
1364 distributedType.getDimSize(distributedDim);
1365 auto newDistributedType =
1366 VectorType::get(newDistributedShape, distributedType.getElementType());
1367 SmallVector<size_t> newRetIndices;
1368 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1369 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1370 newRetIndices);
1371 rewriter.setInsertionPointAfter(newWarpOp);
1372 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1373 extractOp.getSizes(), [](Attribute attr) { return attr; });
1374 // Update the distributed sizes to match the distributed type.
1375 if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
1376 distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1377 distributedType.getDimSize(distributedDim));
1378
1379 // Create a new extract strided slice op that extracts from the
1380 // distributed vector.
1381 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1382 Value newExtract = vector::ExtractStridedSliceOp::create(
1383 rewriter, extractOp.getLoc(), distributedType, distributedVec,
1384 extractOp.getOffsets(),
1385 ArrayAttr::get(rewriter.getContext(), distributedSizes),
1386 extractOp.getStrides());
1387 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1388 newExtract);
1389 return success();
1390 }
1391};
1392
1393/// Pattern to move out vector.extract with a 2-D or higher dimensional source
1394/// vector. The extract is recreated outside of the region on the distributed source.
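/// Illustrative example (added for clarity; not from the upstream docs),
/// assuming a warp size of 32 and distribution along the trailing dimension:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %src = "some_def"() : () -> (vector<4x32xf32>)
///   %e = vector.extract %src[1] : vector<32xf32> from vector<4x32xf32>
///   gpu.yield %e : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x1xf32>) {
///   %src = "some_def"() : () -> (vector<4x32xf32>)
///   gpu.yield %src : vector<4x32xf32>
/// }
/// %r = vector.extract %w[1] : vector<1xf32> from vector<4x1xf32>
/// ```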
1395struct WarpOpExtract : public WarpDistributionPattern {
1396 using Base::Base;
1397 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1398 PatternRewriter &rewriter) const override {
1399 OpOperand *operand =
1400 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1401 if (!operand)
1402 return failure();
1403 unsigned int operandNumber = operand->getOperandNumber();
1404 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1405 VectorType extractSrcType = extractOp.getSourceVectorType();
1406 Location loc = extractOp.getLoc();
1407
1408 // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1409 if (extractSrcType.getRank() <= 1) {
1410 return failure();
1411 }
1412
1413 // All following cases are 2d or higher dimensional source vectors.
1414
1415 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1416 // There is no distribution, this is a broadcast. Simply move the extract
1417 // out of the warp op.
1418 // TODO: This could be optimized. E.g., in case of a scalar result, let
1419 // one lane extract and shuffle the result to all other lanes (same as
1420 // the 1d case).
1421 SmallVector<size_t> newRetIndices;
1422 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1423 rewriter, warpOp, {extractOp.getSource()},
1424 {extractOp.getSourceVectorType()}, newRetIndices);
1425 rewriter.setInsertionPointAfter(newWarpOp);
1426 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1427 // Extract from distributed vector.
1428 Value newExtract = vector::ExtractOp::create(
1429 rewriter, loc, distributedVec, extractOp.getMixedPosition());
1430 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1431 newExtract);
1432 return success();
1433 }
1434
1435 // Find the distributed dimension. There should be exactly one.
1436 auto distributedType =
1437 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1438 auto yieldedType = cast<VectorType>(operand->get().getType());
1439 int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1440 assert(distributedDim != -1 && "could not find distributed dimension");
1441 (void)distributedDim;
1442
1443 // Yield source vector from warp op.
1444 SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1445 for (int i = 0; i < distributedType.getRank(); ++i)
1446 newDistributedShape[i + extractOp.getNumIndices()] =
1447 distributedType.getDimSize(i);
1448 auto newDistributedType =
1449 VectorType::get(newDistributedShape, distributedType.getElementType());
1450 SmallVector<size_t> newRetIndices;
1451 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1452 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1453 newRetIndices);
1454 rewriter.setInsertionPointAfter(newWarpOp);
1455 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1456 // Extract from distributed vector.
1457 Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1458 extractOp.getMixedPosition());
1459 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1460 newExtract);
1461 return success();
1462 }
1463};
1464
1465/// Pattern to move out vector.extract with a scalar result.
1466/// Only supports 1-D and 0-D sources for now.
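/// Illustrative sketch (added for clarity; not from the upstream docs) of the
/// 1-D case, assuming a warp size of 32 so each lane owns 2 elements of the
/// vector<64xf32> source:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %v = "some_def"() : () -> (vector<64xf32>)
///   %e = vector.extract %v[%idx] : f32 from vector<64xf32>
///   gpu.yield %e : f32
/// }
/// ```
/// is rewritten so that the new warp op yields the source (distributed as
/// vector<2xf32>) together with %idx; the lane owning position %idx extracts
/// the element locally and the scalar is propagated to all lanes through the
/// user-provided `WarpShuffleFromIdxFn`.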
1467struct WarpOpExtractScalar : public WarpDistributionPattern {
1468 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1469 PatternBenefit b = 1)
1470 : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1471 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1472 PatternRewriter &rewriter) const override {
1473 OpOperand *operand =
1474 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1475 if (!operand)
1476 return failure();
1477 unsigned int operandNumber = operand->getOperandNumber();
1478 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1479 VectorType extractSrcType = extractOp.getSourceVectorType();
1480 // Only supports 1-D or 0-D sources for now.
1481 if (extractSrcType.getRank() > 1) {
1482 return rewriter.notifyMatchFailure(
1483 extractOp, "only 0-D or 1-D source supported for now");
1484 }
1485 // TODO: Supported shuffle types should be parameterizable, similar to
1486 // `WarpShuffleFromIdxFn`.
1487 if (!extractSrcType.getElementType().isF32() &&
1488 !extractSrcType.getElementType().isInteger(32))
1489 return rewriter.notifyMatchFailure(
1490 extractOp, "only f32/i32 element types are supported");
1491 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1492 Type elType = extractSrcType.getElementType();
1493 VectorType distributedVecType;
1494 if (!is0dOrVec1Extract) {
1495 assert(extractSrcType.getRank() == 1 &&
1496 "expected that extract src rank is 0 or 1");
1497 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1498 return failure();
1499 int64_t elementsPerLane =
1500 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1501 distributedVecType = VectorType::get({elementsPerLane}, elType);
1502 } else {
1503 distributedVecType = extractSrcType;
1504 }
1505 // Yield source vector and position (if present) from warp op.
1506 SmallVector<Value> additionalResults{extractOp.getSource()};
1507 SmallVector<Type> additionalResultTypes{distributedVecType};
1508 additionalResults.append(
1509 SmallVector<Value>(extractOp.getDynamicPosition()));
1510 additionalResultTypes.append(
1511 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1512
1513 Location loc = extractOp.getLoc();
1514 SmallVector<size_t> newRetIndices;
1515 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1516 rewriter, warpOp, additionalResults, additionalResultTypes,
1517 newRetIndices);
1518 rewriter.setInsertionPointAfter(newWarpOp);
1519 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1520
1521 // 0-d (or single-element 1-d) extract: the new warp op broadcasts the
1522 // source vector to all lanes. All lanes extract the scalar.
1523 if (is0dOrVec1Extract) {
1524 Value newExtract;
1525 SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1526 newExtract =
1527 vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
1528 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1529 newExtract);
1530 return success();
1531 }
1532
1533 int64_t staticPos = extractOp.getStaticPosition()[0];
1534 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1535 ? (newWarpOp->getResult(newRetIndices[1]))
1536 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1537 // 1d extract: Distribute the source vector. One lane extracts and shuffles
1538 // the value to all other lanes.
1539 int64_t elementsPerLane = distributedVecType.getShape()[0];
1540 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1541 // tid of extracting thread: pos / elementsPerLane
1542 Value broadcastFromTid = affine::makeComposedAffineApply(
1543 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1544 // Extract at position: pos % elementsPerLane
1545 Value newPos =
1546 elementsPerLane == 1
1547 ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
1548 : affine::makeComposedAffineApply(rewriter, loc,
1549 sym0 % elementsPerLane, pos);
1550 Value extracted =
1551 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1552
1553 // Shuffle the extracted value to all lanes.
1554 Value shuffled = warpShuffleFromIdxFn(
1555 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1556 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1557 return success();
1558 }
1559
1560private:
1561 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1562};
1563
1564/// Pattern to move out vector.insert with a scalar input.
1565/// Only supports 1-D and 0-D destinations for now.
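/// Illustrative sketch (added for clarity; not from the upstream docs),
/// assuming a warp size of 32 and a vector<64xf32> destination distributed as
/// vector<2xf32> per lane:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   %d = "some_def"() : () -> (vector<64xf32>)
///   %s = "some_def"() : () -> (f32)
///   %i = vector.insert %s, %d[%idx] : f32 into vector<64xf32>
///   gpu.yield %i : vector<64xf32>
/// }
/// ```
/// is rewritten so that the new warp op yields the destination (as
/// vector<2xf32>), the scalar %s, and %idx; only the lane owning position %idx
/// performs the insert inside an scf.if, while all other lanes forward their
/// distributed destination unchanged.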
1566struct WarpOpInsertScalar : public WarpDistributionPattern {
1567 using Base::Base;
1568 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1569 PatternRewriter &rewriter) const override {
1570 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1571 if (!operand)
1572 return failure();
1573 unsigned int operandNumber = operand->getOperandNumber();
1574 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1575 VectorType vecType = insertOp.getDestVectorType();
1576 VectorType distrType =
1577 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1578
1579 // Only supports 1-D or 0-D destinations for now.
1580 if (vecType.getRank() > 1) {
1581 return rewriter.notifyMatchFailure(
1582 insertOp, "only 0-D or 1-D destination supported for now");
1583 }
1584
1585 // Yield destination vector, source scalar and position from warp op.
1586 SmallVector<Value> additionalResults{insertOp.getDest(),
1587 insertOp.getValueToStore()};
1588 SmallVector<Type> additionalResultTypes{
1589 distrType, insertOp.getValueToStore().getType()};
1590 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1591 additionalResultTypes.append(
1592 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1593
1594 Location loc = insertOp.getLoc();
1595 SmallVector<size_t> newRetIndices;
1596 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1597 rewriter, warpOp, additionalResults, additionalResultTypes,
1598 newRetIndices);
1599 rewriter.setInsertionPointAfter(newWarpOp);
1600 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1601 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1602 rewriter.setInsertionPointAfter(newWarpOp);
1603
1604 OpFoldResult pos;
1605 if (vecType.getRank() != 0) {
1606 int64_t staticPos = insertOp.getStaticPosition()[0];
1607 pos = ShapedType::isDynamic(staticPos)
1608 ? (newWarpOp->getResult(newRetIndices[2]))
1609 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1610 }
1611
1612 // This condition is always true for 0-d vectors.
1613 if (vecType == distrType) {
1614 Value newInsert;
1615 SmallVector<OpFoldResult> indices;
1616 if (pos) {
1617 indices.push_back(pos);
1618 }
1619 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1620 distributedVec, indices);
1621 // Broadcast: Simply move the vector.insert op out.
1622 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1623 newInsert);
1624 return success();
1625 }
1626
1627 // This is a distribution. Only one lane should insert.
1628 int64_t elementsPerLane = distrType.getShape()[0];
1629 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1630 // tid of inserting lane: pos / elementsPerLane
1631 Value insertingLane = affine::makeComposedAffineApply(
1632 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1633 // Insert position: pos % elementsPerLane
1634 OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1635 rewriter, loc, sym0 % elementsPerLane, pos);
1636 Value isInsertingLane =
1637 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1638 newWarpOp.getLaneid(), insertingLane);
1639 Value newResult =
1640 scf::IfOp::create(
1641 rewriter, loc, isInsertingLane,
1642 /*thenBuilder=*/
1643 [&](OpBuilder &builder, Location loc) {
1644 Value newInsert = vector::InsertOp::create(
1645 builder, loc, newSource, distributedVec, newPos);
1646 scf::YieldOp::create(builder, loc, newInsert);
1647 },
1648 /*elseBuilder=*/
1649 [&](OpBuilder &builder, Location loc) {
1650 scf::YieldOp::create(builder, loc, distributedVec);
1651 })
1652 .getResult(0);
1653 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1654 return success();
1655 }
1656};
1657
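/// Pattern to move out vector.insert with a 2-D or higher dimensional
/// destination vector. Illustrative example (added for clarity; not from the
/// upstream docs), assuming 32 lanes and distribution along the trailing
/// dimension:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<128x3xf32>) {
///   %s = "some_def"() : () -> (vector<96xf32>)
///   %d = "some_def"() : () -> (vector<128x96xf32>)
///   %i = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
///   gpu.yield %i : vector<128x96xf32>
/// }
/// ```
/// becomes
/// ```
/// %w:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<3xf32>, vector<128x3xf32>) {
///   %s = "some_def"() : () -> (vector<96xf32>)
///   %d = "some_def"() : () -> (vector<128x96xf32>)
///   gpu.yield %s, %d : vector<96xf32>, vector<128x96xf32>
/// }
/// %r = vector.insert %w#0, %w#1 [2] : vector<3xf32> into vector<128x3xf32>
/// ```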
1658struct WarpOpInsert : public WarpDistributionPattern {
1659 using Base::Base;
1660 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1661 PatternRewriter &rewriter) const override {
1662 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1663 if (!operand)
1664 return failure();
1665 unsigned int operandNumber = operand->getOperandNumber();
1666 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1667 Location loc = insertOp.getLoc();
1668
1669 // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1670 if (insertOp.getDestVectorType().getRank() <= 1) {
1671 return failure();
1672 }
1673
1674 // All following cases are 2d or higher dimensional destination vectors.
1675
1676 if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1677 // There is no distribution, this is a broadcast. Simply move the insert
1678 // out of the warp op.
1679 SmallVector<size_t> newRetIndices;
1680 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1681 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1682 {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1683 newRetIndices);
1684 rewriter.setInsertionPointAfter(newWarpOp);
1685 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1686 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1687 Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1688 distributedDest,
1689 insertOp.getMixedPosition());
1690 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1691 newResult);
1692 return success();
1693 }
1694
1695 // Find the distributed dimension. There should be exactly one.
1696 auto distrDestType =
1697 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1698 auto yieldedType = cast<VectorType>(operand->get().getType());
1699 int64_t distrDestDim = -1;
1700 for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1701 if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1702 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1703 // support distributing multiple dimensions in the future.
1704 assert(distrDestDim == -1 && "found multiple distributed dims");
1705 distrDestDim = i;
1706 }
1707 }
1708 assert(distrDestDim != -1 && "could not find distributed dimension");
1709
1710 // Compute the distributed source vector type.
1711 VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1712 SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1713 // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1714 // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1715 // insert a smaller vector<3xf32>.
1716 // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1717 // case, one lane will insert the source vector<96xf32>. The other
1718 // lanes will not do anything.
1719 int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1720 if (distrSrcDim >= 0)
1721 distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1722 auto distrSrcType =
1723 VectorType::get(distrSrcShape, distrDestType.getElementType());
1724
1725 // Yield source and dest vectors from warp op.
1726 SmallVector<size_t> newRetIndices;
1727 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1728 rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1729 {distrSrcType, distrDestType}, newRetIndices);
1730 rewriter.setInsertionPointAfter(newWarpOp);
1731 Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1732 Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1733
1734 // Insert into the distributed vector.
1735 Value newResult;
1736 if (distrSrcDim >= 0) {
1737 // Every lane inserts a small piece.
1738 newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1739 distributedDest,
1740 insertOp.getMixedPosition());
1741 } else {
1742 // One lane inserts the entire source vector.
1743 int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1744 SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1745 SmallVector<int64_t> newPos = getAsIntegers(pos);
1746 // tid of inserting lane: pos / elementsPerLane
1747 Value insertingLane = arith::ConstantIndexOp::create(
1748 rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1749 Value isInsertingLane =
1750 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1751 newWarpOp.getLaneid(), insertingLane);
1752 // Insert position: pos % elementsPerLane
1753 newPos[distrDestDim] %= elementsPerLane;
1754 auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1755 Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1756 distributedDest, newPos);
1757 scf::YieldOp::create(builder, loc, newInsert);
1758 };
1759 auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1760 scf::YieldOp::create(builder, loc, distributedDest);
1761 };
1762 newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1763 /*thenBuilder=*/insertingBuilder,
1764 /*elseBuilder=*/nonInsertingBuilder)
1765 .getResult(0);
1766 }
1767
1768 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1769 return success();
1770 }
1771};
1772
1773/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1774/// the scf.if is the last operation in the region so that it doesn't
1775/// change the order of execution. This creates a new scf.if after the
1776/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1777/// the "inner" WarpExecuteOnLane0Op. Example:
1778/// ```
1779/// gpu.warp_execute_on_lane_0(%laneid)[32] {
1780/// %payload = ... : vector<32xindex>
1781/// scf.if %pred {
1782/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1783/// vector<32xindex>
1784/// }
1785/// gpu.yield
1786/// }
1787/// ```
1788/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1789/// %payload = ... : vector<32xindex>
1790/// gpu.yield %payload : vector<32xindex>
1791/// }
1792/// scf.if %pred {
1793/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1794/// ^bb0(%arg1: vector<32xindex>):
1795/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1796/// }
1797/// }
1798/// ```
1799struct WarpOpScfIfOp : public WarpDistributionPattern {
1800 WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1801 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1802 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1803 PatternRewriter &rewriter) const override {
1804 gpu::YieldOp warpOpYield = warpOp.getTerminator();
1805 // Only pick up `IfOp` if it is the last op in the region.
1806 Operation *lastNode = warpOpYield->getPrevNode();
1807 auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1808 if (!ifOp)
1809 return failure();
1810
1811 // The current `WarpOp` can yield two types of values:
1812 // 1. Not results of `IfOp`:
1813 // Preserve them in the new `WarpOp`.
1814 // Collect their yield index to remap the usages.
1815 // 2. Results of `IfOp`:
1816 // They are not part of the new `WarpOp` results.
1817 // Map current warp's yield operand index to `IfOp` result idx.
1818 SmallVector<Value> nonIfYieldValues;
1819 SmallVector<unsigned> nonIfYieldIndices;
1820 llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1821 llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1822 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1823 const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
1824 if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
1825 nonIfYieldValues.push_back(yieldOperand.get());
1826 nonIfYieldIndices.push_back(yieldOperandIdx);
1827 continue;
1828 }
1829 OpResult ifResult = cast<OpResult>(yieldOperand.get());
1830 const unsigned ifResultIdx = ifResult.getResultNumber();
1831 ifResultMapping[yieldOperandIdx] = ifResultIdx;
1832 // If this `ifOp` result is a vector type and it is yielded by the
1833 // `WarpOp`, keep track of the distributed type for this result.
1834 if (!isa<VectorType>(ifResult.getType()))
1835 continue;
1836 VectorType distType =
1837 cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1838 ifResultDistTypes[ifResultIdx] = distType;
1839 }
1840
1841 // Collect `WarpOp`-defined values that are used inside `ifOp`; the new warp
1842 // op must return them.
1843 auto [escapingValuesThen, escapingValueInputTypesThen,
1844 escapingValueDistTypesThen] =
1845 getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1846 distributionMapFn);
1847 auto [escapingValuesElse, escapingValueInputTypesElse,
1848 escapingValueDistTypesElse] =
1849 getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1850 distributionMapFn);
1851 if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1852 llvm::is_contained(escapingValueDistTypesElse, Type{}))
1853 return failure();
1854
1855 // The new `WarpOp` yields values grouped in the following order:
1856 // 1. The branch condition.
1857 // 2. Values escaping into the then branch.
1858 // 3. Values escaping into the else branch.
1859 // 4. All non-`ifOp` yielded values.
1860 SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1861 newWarpOpYieldValues.append(escapingValuesThen.begin(),
1862 escapingValuesThen.end());
1863 newWarpOpYieldValues.append(escapingValuesElse.begin(),
1864 escapingValuesElse.end());
1865 SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1866 newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1867 escapingValueDistTypesThen.end());
1868 newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1869 escapingValueDistTypesElse.end());
1870
1871 for (auto [idx, val] :
1872 llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1873 newWarpOpYieldValues.push_back(val);
1874 newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1875 }
1876 // Replace the old `WarpOp` with the new one that has additional yield
1877 // values and types.
1878 SmallVector<size_t> newIndices;
1879 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1880 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
1881 // `ifOp` returns the result of the inner warp op.
1882 SmallVector<Type> newIfOpDistResTypes;
1883 for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1884 Type distType = cast<Value>(res).getType();
1885 if (auto vecType = dyn_cast<VectorType>(distType)) {
1886 AffineMap map = distributionMapFn(cast<Value>(res));
1887 // Fall back to the affine map if the distributed type was not recorded above.
1888 distType = ifResultDistTypes.count(i)
1889 ? ifResultDistTypes[i]
1890 : getDistributedType(
1891 vecType, map,
1892 map.isEmpty() ? 1 : newWarpOp.getWarpSize());
1893 }
1894 newIfOpDistResTypes.push_back(distType);
1895 }
1896 // Create a new `IfOp` outside the new `WarpOp` region.
1897 OpBuilder::InsertionGuard g(rewriter);
1898 rewriter.setInsertionPointAfter(newWarpOp);
1899 auto newIfOp = scf::IfOp::create(
1900 rewriter, ifOp.getLoc(), newIfOpDistResTypes,
1901 newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
1902 static_cast<bool>(ifOp.elseBlock()));
1903 auto encloseRegionInWarpOp =
1904 [&](Block *oldIfBranch, Block *newIfBranch,
1905 llvm::SmallSetVector<Value, 32> &escapingValues,
1906 SmallVector<Type> &escapingValueInputTypes,
1907 size_t warpResRangeStart) {
1908 OpBuilder::InsertionGuard g(rewriter);
1909 if (!newIfBranch)
1910 return;
1911 rewriter.setInsertionPointToStart(newIfBranch);
1912 llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1913 SmallVector<Value> innerWarpInputVals;
1914 SmallVector<Type> innerWarpInputTypes;
1915 for (size_t i = 0; i < escapingValues.size();
1916 ++i, ++warpResRangeStart) {
1917 innerWarpInputVals.push_back(
1918 newWarpOp.getResult(newIndices[warpResRangeStart]));
1919 escapeValToBlockArgIndex[escapingValues[i]] =
1920 innerWarpInputTypes.size();
1921 innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1922 }
1923 auto innerWarp = WarpExecuteOnLane0Op::create(
1924 rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1925 newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1926 innerWarpInputVals, innerWarpInputTypes);
1927
1928 innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1929 innerWarp.getWarpRegion().addArguments(
1930 innerWarpInputTypes,
1931 SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1932
1933 SmallVector<Value> yieldOperands;
1934 for (Value operand : oldIfBranch->getTerminator()->getOperands())
1935 yieldOperands.push_back(operand);
1936 rewriter.eraseOp(oldIfBranch->getTerminator());
1937
1938 rewriter.setInsertionPointToEnd(innerWarp.getBody());
1939 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1940 rewriter.setInsertionPointAfter(innerWarp);
1941 scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1942
1943 // Update any users of escaping values that were forwarded to the
1944 // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1945 innerWarp.walk([&](Operation *op) {
1946 for (OpOperand &operand : op->getOpOperands()) {
1947 auto it = escapeValToBlockArgIndex.find(operand.get());
1948 if (it == escapeValToBlockArgIndex.end())
1949 continue;
1950 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1951 }
1952 });
1953 mlir::vector::moveScalarUniformCode(innerWarp);
1954 };
1955 encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1956 &newIfOp.getThenRegion().front(), escapingValuesThen,
1957 escapingValueInputTypesThen, 1);
1958 if (!ifOp.getElseRegion().empty())
1959 encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1960 &newIfOp.getElseRegion().front(),
1961 escapingValuesElse, escapingValueInputTypesElse,
1962 1 + escapingValuesThen.size());
1963 // Update users of the `WarpOp` results that were yielded from the `IfOp` to
1964 // use the corresponding new `IfOp` results directly.
1965 for (auto [origIdx, newIdx] : ifResultMapping)
1966 rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
1967 newIfOp.getResult(newIdx), newIfOp);
1968 return success();
1969 }
1970
1971private:
1972 DistributionMapFn distributionMapFn;
1973};
1974
1975/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1976/// the scf.ForOp is the last operation in the region so that it doesn't
1977/// change the order of execution. This creates a new scf.for region after the
1978/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1979/// WarpExecuteOnLane0Op region. Example:
1980/// ```
1981/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1982/// ...
1983/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1984/// -> (vector<128xf32>) {
1985/// ...
1986/// scf.yield %r : vector<128xf32>
1987/// }
1988/// gpu.yield %v1 : vector<128xf32>
1989/// }
1990/// ```
1991/// To:
1992/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1993/// ...
1994/// gpu.yield %v : vector<128xf32>
1995/// }
1996/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
1997/// -> (vector<4xf32>) {
1998/// %iw = gpu.warp_execute_on_lane_0(%laneid)
1999/// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
2000/// ^bb0(%arg: vector<128xf32>):
2001/// ...
2002/// gpu.yield %ir : vector<128xf32>
2003/// }
2004/// scf.yield %iw : vector<4xf32>
2005/// }
2006/// ```
2007struct WarpOpScfForOp : public WarpDistributionPattern {
2008
2009 WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
2010 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
2011 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2012 PatternRewriter &rewriter) const override {
2013 gpu::YieldOp warpOpYield = warpOp.getTerminator();
2014 // Only pick up `ForOp` if it is the last op in the region.
2015 Operation *lastNode = warpOpYield->getPrevNode();
2016 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
2017 if (!forOp)
2018 return failure();
2019 // Collect values defined in the `WarpOp` body (outside the `ForOp`) that are
2020 // used inside the `ForOp`. Those values need to be returned by the new warp op.
2021 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2022 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2023 distributionMapFn);
2024 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2025 return failure();
2026 // `WarpOp` can yield two types of values:
2027 // 1. Values that are not results of the `ForOp`:
2028 // These values must also be yielded by the new `WarpOp`. Also, we need
2029 // to record the index mapping for these values to replace them later.
2030 // 2. Values that are results of the `ForOp`:
2031 // In this case, we record the index mapping between the `WarpOp` result
2032 // index and matching `ForOp` result index.
2033 // Additionally, we keep track of the distributed types for all `ForOp`
2034 // vector results.
2035 SmallVector<Value> nonForYieldedValues;
2036 SmallVector<unsigned> nonForResultIndices;
2037 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2038 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2039 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2040 // Yielded value is not a result of the forOp.
2041 if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
2042 nonForYieldedValues.push_back(yieldOperand.get());
2043 nonForResultIndices.push_back(yieldOperand.getOperandNumber());
2044 continue;
2045 }
2046 OpResult forResult = cast<OpResult>(yieldOperand.get());
2047 unsigned int forResultNumber = forResult.getResultNumber();
2048 forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
2049 // If this `ForOp` result is a vector type and it is yielded by the
2050 // `WarpOp`, keep track of the distributed type for this result.
2051 if (!isa<VectorType>(forResult.getType()))
2052 continue;
2053 VectorType distType = cast<VectorType>(
2054 warpOp.getResult(yieldOperand.getOperandNumber()).getType());
2055 forResultDistTypes[forResultNumber] = distType;
2056 }
2057
2058 // The newly created `WarpOp` will yield values in the following order:
2059 // 1. Loop bounds.
2060 // 2. All init args of the `ForOp`.
2061 // 3. All escaping values.
2062 // 4. All non-`ForOp` yielded values.
2063 SmallVector<Value> newWarpOpYieldValues;
2064 SmallVector<Type> newWarpOpDistTypes;
2065 newWarpOpYieldValues.insert(
2066 newWarpOpYieldValues.end(),
2067 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2068 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2069 {forOp.getLowerBound().getType(),
2070 forOp.getUpperBound().getType(),
2071 forOp.getStep().getType()});
2072 for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2073 newWarpOpYieldValues.push_back(initArg);
2074 // Compute the distributed type for this init arg.
2075 Type distType = initArg.getType();
2076 if (auto vecType = dyn_cast<VectorType>(distType)) {
2077 // If the `ForOp` result corresponding to this init arg is already yielded
2078 // by the `WarpOp`, we can get the distributed type from the
2079 // `forResultDistTypes` map. Otherwise, we compute it using distributionMapFn.
2080 AffineMap map = distributionMapFn(initArg);
2081 distType =
2082 forResultDistTypes.count(i)
2083 ? forResultDistTypes[i]
2084 : getDistributedType(vecType, map,
2085 map.isEmpty() ? 1 : warpOp.getWarpSize());
2086 }
2087 newWarpOpDistTypes.push_back(distType);
2088 }
2089 // Insert escaping values and their distributed types.
2090 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2091 escapingValues.begin(), escapingValues.end());
2092 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2093 escapingValueDistTypes.begin(),
2094 escapingValueDistTypes.end());
2095 // Next, we insert all non-`ForOp` yielded values and their distributed
2096 // types.
2097 for (auto [i, v] :
2098 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2099 newWarpOpYieldValues.push_back(v);
2100 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2101 }
2102 // Create the new `WarpOp` with the updated yield values and types.
2103 SmallVector<size_t> newIndices;
2104 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2105 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2106
2107 // Next, we create a new `ForOp` with the init args yielded by the new
2108 // `WarpOp`.
2109 const unsigned initArgsStartIdx = 3; // After loop bounds.
2110 const unsigned escapingValuesStartIdx =
2111 initArgsStartIdx +
2112 forOp.getInitArgs().size(); // `ForOp` init args are positioned before
2113 // escaping values in the new `WarpOp`.
2114 SmallVector<Value> newForOpOperands;
2115 for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2116 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2117
2118 // Create a new `ForOp` outside the new `WarpOp` region.
2119 OpBuilder::InsertionGuard g(rewriter);
2120 rewriter.setInsertionPointAfter(newWarpOp);
2121 auto newForOp = scf::ForOp::create(
2122 rewriter, forOp.getLoc(),
2123 /*lowerBound=*/newWarpOp.getResult(newIndices[0]),
2124 /*upperBound=*/newWarpOp.getResult(newIndices[1]),
2125 /*step=*/newWarpOp.getResult(newIndices[2]), newForOpOperands,
2126 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
2127 // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
2128 // newly created `ForOp`. This `WarpOp` will contain all ops that were
2129 // contained within the original `ForOp` body.
2130 rewriter.setInsertionPointToStart(newForOp.getBody());
2131
2132 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2133 newForOp.getRegionIterArgs().end());
2134 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2135 forOp.getResultTypes().end());
2136 // Escaping values are forwarded to the inner `WarpOp` as its (additional)
2137 // arguments. We keep track of the mapping between these values and their
2138 // argument index in the inner `WarpOp` (to replace users later).
2139 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2140 for (size_t i = escapingValuesStartIdx;
2141 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2142 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2143 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2144 innerWarpInputType.size();
2145 innerWarpInputType.push_back(
2146 escapingValueInputTypes[i - escapingValuesStartIdx]);
2147 }
2148 // Create the inner `WarpOp` with the new input values and types.
2149 auto innerWarp = WarpExecuteOnLane0Op::create(
2150 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2151 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2152 innerWarpInputType);
2153
2154 // Inline the `ForOp` body into the inner `WarpOp` body.
2155 SmallVector<Value> argMapping;
2156 argMapping.push_back(newForOp.getInductionVar());
2157 for (Value args : innerWarp.getBody()->getArguments())
2158 argMapping.push_back(args);
2159
2160 argMapping.resize(forOp.getBody()->getNumArguments());
2161 SmallVector<Value> yieldOperands;
2162 for (Value operand : forOp.getBody()->getTerminator()->getOperands())
2163 yieldOperands.push_back(operand);
2164
2165 rewriter.eraseOp(forOp.getBody()->getTerminator());
2166 rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2167
2168 // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
2169 // original `ForOp` results.
2170 rewriter.setInsertionPointToEnd(innerWarp.getBody());
2171 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2172 rewriter.setInsertionPointAfter(innerWarp);
2173 // Insert a scf.yield op at the end of the new `ForOp` body that yields
2174 // the inner `WarpOp` results.
2175 if (!innerWarp.getResults().empty())
2176 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2177
2178 // Update the users of the new `WarpOp` results that were coming from the
2179 // original `ForOp` to the corresponding new `ForOp` result.
2180 for (auto [origIdx, newIdx] : forResultMapping)
2181 rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
2182 newForOp.getResult(newIdx), newForOp);
2183 // Update any users of escaping values that were forwarded to the
2184 // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
2185 newForOp.walk([&](Operation *op) {
2186 for (OpOperand &operand : op->getOpOperands()) {
2187 auto it = argIndexMapping.find(operand.get());
2188 if (it == argIndexMapping.end())
2189 continue;
2190 operand.set(innerWarp.getBodyRegion().getArgument(it->second));
2191 }
2192 });
2193
2194 // Finally, hoist out any now uniform code from the inner `WarpOp`.
2195 mlir::vector::moveScalarUniformCode(innerWarp);
2196 return success();
2197 }
2198
2199private:
2200 DistributionMapFn distributionMapFn;
2201};
2202
2203/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
2204/// The vector is reduced in parallel. Currently limited to 1-D vectors whose
2205/// size is a multiple of the warp size. E.g.:
2206/// ```
2207/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
2208/// %0 = "some_def"() : () -> (vector<32xf32>)
2209/// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
2210/// gpu.yield %1 : f32
2211/// }
2212/// ```
2213/// is lowered to:
2214/// ```
2215/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
2216/// %1 = "some_def"() : () -> (vector<32xf32>)
2217/// gpu.yield %1 : vector<32xf32>
2218/// }
2219/// %a = vector.extract %0[0] : f32 from vector<1xf32>
2220/// %r = ("warp.reduction %a")
2221/// ```
2222struct WarpOpReduction : public WarpDistributionPattern {
2223 WarpOpReduction(MLIRContext *context,
2224 DistributedReductionFn distributedReductionFn,
2225 PatternBenefit benefit = 1)
2226 : WarpDistributionPattern(context, benefit),
2227 distributedReductionFn(std::move(distributedReductionFn)) {}
2228
2229 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2230 PatternRewriter &rewriter) const override {
2231 OpOperand *yieldOperand =
2232 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2233 if (!yieldOperand)
2234 return failure();
2235
2236 auto reductionOp =
2237 cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
2238 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2239 // Only rank 1 vectors supported.
2240 if (vectorType.getRank() != 1)
2241 return rewriter.notifyMatchFailure(
2242 warpOp, "Only rank 1 reductions can be distributed.");
2243 // Only vectors whose size is a multiple of the warp size are supported.
2244 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2245 return rewriter.notifyMatchFailure(
2246 warpOp, "Reduction vector dimension must be a multiple of warp size.");
2247 if (!reductionOp.getType().isIntOrFloat())
2248 return rewriter.notifyMatchFailure(
2249 warpOp, "Reduction distribution currently only supports floats and "
2250 "integer types.");
2251
2252 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2253 // Return vector that will be reduced from the WarpExecuteOnLane0Op.
2254 unsigned operandIndex = yieldOperand->getOperandNumber();
2255 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2256 SmallVector<Type> retTypes = {
2257 VectorType::get({numElements}, reductionOp.getType())};
2258 if (reductionOp.getAcc()) {
2259 yieldValues.push_back(reductionOp.getAcc());
2260 retTypes.push_back(reductionOp.getAcc().getType());
2261 }
2262 SmallVector<size_t> newRetIndices;
2263 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2264 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2265 rewriter.setInsertionPointAfter(newWarpOp);
2266
2267 // Obtain data to reduce for a single lane.
2268 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2269 // Distribute and reduce across threads.
2270 Value fullReduce =
2271 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2272 reductionOp.getKind(), newWarpOp.getWarpSize());
2273 if (reductionOp.getAcc()) {
2274 fullReduce = vector::makeArithReduction(
2275 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2276 newWarpOp.getResult(newRetIndices[1]));
2277 }
2278 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
2279 return success();
2280 }
2281
2282private:
2283 DistributedReductionFn distributedReductionFn;
2284};
2285
2286} // namespace
2287
2293
2294void mlir::vector::populateDistributeTransferWriteOpPatterns(
2295 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2296 unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2297 patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2298 maxNumElementsToExtract, benefit);
2299}
2300
2301void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2302 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2303 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2304 PatternBenefit readBenefit) {
2305 patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2306 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2307 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
2308 WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
2309 WarpOpCreateMask<vector::CreateMaskOp>,
2310 WarpOpCreateMask<vector::ConstantMaskOp>,
2311 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2312 patterns.getContext(), benefit);
2313 patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2314 benefit);
2315 patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2316 benefit);
2317 patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2318 benefit);
2319}
2320
2321void mlir::vector::populateDistributeReduction(
2322 RewritePatternSet &patterns,
2323 const DistributedReductionFn &distributedReductionFn,
2324 PatternBenefit benefit) {
2325 patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2326 benefit);
2327}
2328
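// Illustrative usage sketch (added for clarity; not part of this file): one way
// a transform might combine the populate* entry points above with the greedy
// pattern driver. The distribution map used here (distribute the innermost
// vector dimension across the warp) and the `applyPatternsGreedily` call are
// assumptions of the example, not requirements of the API.
//
//   RewritePatternSet patterns(ctx);
//   DistributionMapFn mapFn = [](Value val) {
//     // Assumes `val` is a non-0-D vector value.
//     auto vecType = cast<VectorType>(val.getType());
//     int64_t rank = vecType.getRank();
//     return AffineMap::get(rank, 0,
//                           {getAffineDimExpr(rank - 1, val.getContext())},
//                           val.getContext());
//   };
//   populateDistributeTransferWriteOpPatterns(patterns, mapFn,
//                                             /*maxNumElementsToExtract=*/1);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));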
2329/// Helper to know if an op can be hoisted out of the region.
2330static bool canBeHoisted(Operation *op,
2331 function_ref<bool(Value)> definedOutside) {
2332 return llvm::all_of(op->getOperands(), definedOutside) &&
2333 isMemoryEffectFree(op) && op->getNumRegions() == 0;
2334}
2335
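/// Illustrative example (added for clarity; not from the upstream docs) of the
/// effect of moveScalarUniformCode, assuming %a and %b are defined outside the
/// region:
/// ```
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
///   %sum = arith.addi %a, %b : index
///   %v = vector.broadcast %sum : index to vector<32xindex>
///   ...
///   gpu.yield
/// }
/// ```
/// becomes
/// ```
/// %sum = arith.addi %a, %b : index
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
///   %v = vector.broadcast %sum : index to vector<32xindex>
///   ...
///   gpu.yield
/// }
/// ```
/// The scalar arith.addi is memory-effect free and has no vector result, so it
/// is hoisted; the vector.broadcast stays because it produces a vector.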
2336void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2337 Block *body = warpOp.getBody();
2338
2339 // Keep track of the ops we want to hoist.
2340 llvm::SmallSetVector<Operation *, 8> opsToMove;
2341
2342 // Helper to check if a value is or will be defined outside of the region.
2343 auto isDefinedOutsideOfBody = [&](Value value) {
2344 auto *definingOp = value.getDefiningOp();
2345 return (definingOp && opsToMove.count(definingOp)) ||
2346 warpOp.isDefinedOutsideOfRegion(value);
2347 };
2348
2349 // Do not use walk here, as we do not want to go into nested regions and hoist
2350 // operations from there.
2351 for (auto &op : body->without_terminator()) {
2352 bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2353 return isa<VectorType>(result.getType());
2354 });
2355 if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2356 opsToMove.insert(&op);
2357 }
2358
2359 // Move all the ops marked as uniform outside of the region.
2360 for (Operation *op : opsToMove)
2361 op->moveBefore(warpOp);
2362}