MLIR 23.0.0git
VectorDistribute.cpp
Go to the documentation of this file.
1//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
18#include "mlir/IR/AffineExpr.h"
19#include "mlir/IR/Attributes.h"
23#include "llvm/ADT/SetVector.h"
24#include "llvm/ADT/SmallBitVector.h"
25#include "llvm/ADT/SmallVectorExtras.h"
26#include "llvm/Support/FormatVariadic.h"
27#include <utility>
28
29using namespace mlir;
30using namespace mlir::vector;
31using namespace mlir::gpu;
32
33/// Currently the distribution map is implicit based on the vector shape. In the
34/// future it will be part of the op.
35/// Example:
36/// ```
37/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
38/// ...
39/// gpu.yield %3 : vector<32x16x64xf32>
40/// }
41/// ```
42/// Would have an implicit map of:
43/// `(d0, d1, d2) -> (d0, d2)`
44static AffineMap calculateImplicitMap(VectorType sequentialType,
45 VectorType distributedType) {
47 perm.reserve(1);
48 // Check which dimensions of the sequential type are different than the
49 // dimensions of the distributed type to know the distributed dimensions. Then
50 // associate each distributed dimension to an ID in order.
51 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
52 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
53 perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
54 }
55 auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
56 distributedType.getContext());
57 return map;
58}
59
60/// Given a sequential and distributed vector type, returns the distributed
61/// dimension. This function expects that only a single dimension is
62/// distributed.
63static int getDistributedDim(VectorType sequentialType,
64 VectorType distributedType) {
65 assert(sequentialType.getRank() == distributedType.getRank() &&
66 "sequential and distributed vector types must have the same rank");
67 int64_t distributedDim = -1;
68 for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
69 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
70 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
71 // support distributing multiple dimensions in the future.
72 assert(distributedDim == -1 && "found multiple distributed dims");
73 distributedDim = i;
74 }
75 }
76 return distributedDim;
77}
78
79namespace {
80
/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    // The distribution map is only meaningful when both values are vectors;
    // scalars transit through a memref<1xT> and need no map.
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  /// Return the offset `laneId * distributedSize` at which the current lane
  /// accesses its slice of the `index`-th (distributed) dimension of the
  /// sequential buffer.
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_thread_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return memref::StoreOp::create(b, loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    // A distributed value is stored at per-lane offsets along the distributed
    // dims; a sequential value is stored at offset zero everywhere.
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferWriteOp::create(
        b, loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_thread_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// When broadcastMode is true, the load is not distributed to account for
  /// the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///   gpu.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return memref::LoadOp::create(b, loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential type.");
    // Mirror image of buildStore: a distributed load reads each lane's slice
    // at per-lane offsets; a sequential load reads from offset zero.
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferReadOp::create(
        b, loc, cast<VectorType>(type), buffer, indices,
        /*padding=*/std::nullopt,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};
187
188} // namespace
189
190// Clones `op` into a new operation that takes `operands` and returns
191// `resultTypes`.
193 Location loc, Operation *op,
194 ArrayRef<Value> operands,
195 ArrayRef<Type> resultTypes) {
196 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
197 op->getAttrs());
198 return rewriter.create(res);
199}
200
201namespace {
202
/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value isLane0 = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
                                  /*withElseRegion=*/false);
    // Drop the auto-created scf.yield; a fresh one is re-inserted in step 7,
    // after the warp body (with its gpu.yield terminator) has been merged in.
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSynchronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSynchronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    scf::YieldOp::create(rewriter, yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};
342
343/// Return the distributed vector type based on the original type and the
344/// distribution map. The map is expected to have a dimension equal to the
345/// original type rank and should be a projection where the results are the
346/// distributed dimensions. If the number of results is zero there is no
347/// distribution (i.e. original type is returned).
348/// Otherwise, The number of results should be equal to the number
349/// of warp sizes which is currently limited to 1.
350/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
351/// and a warp size of 16 would distribute the second dimension (associated to
352/// d1) and return vector<16x2x64>
353static VectorType getDistributedType(VectorType originalType, AffineMap map,
354 int64_t warpSize) {
355 // If the map has zero results, return the original type.
356 if (map.getNumResults() == 0)
357 return originalType;
358 SmallVector<int64_t> targetShape(originalType.getShape());
359 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
360 unsigned position = map.getDimPosition(i);
361 if (targetShape[position] % warpSize != 0) {
362 if (warpSize % targetShape[position] != 0) {
363 return VectorType();
364 }
365 warpSize /= targetShape[position];
366 targetShape[position] = 1;
367 continue;
368 }
369 targetShape[position] = targetShape[position] / warpSize;
370 warpSize = 1;
371 break;
372 }
373 if (warpSize != 1) {
374 return VectorType();
375 }
376 VectorType targetType =
377 VectorType::get(targetShape, originalType.getElementType());
378 return targetType;
379}
380
381/// Given a warpOp that contains ops with regions, the corresponding op's
382/// "inner" region and the distributionMapFn, get all values used by the op's
383/// region that are defined within the warpOp, but outside the inner region.
384/// Return the set of values, their types and their distributed types.
385std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
387getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
388 DistributionMapFn distributionMapFn) {
389 llvm::SmallSetVector<Value, 32> escapingValues;
390 SmallVector<Type> escapingValueTypes;
391 SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
392 if (innerRegion.empty())
393 return {std::move(escapingValues), std::move(escapingValueTypes),
394 std::move(escapingValueDistTypes)};
395 mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
396 Operation *parent = operand->get().getParentRegion()->getParentOp();
397 if (warpOp->isAncestor(parent)) {
398 if (!escapingValues.insert(operand->get()))
399 return;
400 Type distType = operand->get().getType();
401 if (auto vecType = dyn_cast<VectorType>(distType)) {
402 AffineMap map = distributionMapFn(operand->get());
403 distType = getDistributedType(vecType, map,
404 map.isEmpty() ? 1 : warpOp.getWarpSize());
405 }
406 escapingValueTypes.push_back(operand->get().getType());
407 escapingValueDistTypes.push_back(distType);
408 }
409 });
410 return {std::move(escapingValues), std::move(escapingValueTypes),
411 std::move(escapingValueDistTypes)};
412}
413
414/// Distribute transfer_write ops based on the affine map returned by
415/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
416/// will not be distributed (it should be less than the warp size).
417///
418/// Example:
419/// ```
420/// %0 = gpu.warp_execute_on_lane_0(%id){
421/// ...
422/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
423/// gpu.yield
424/// }
425/// ```
426/// To
427/// ```
428/// %r:3 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
429/// ...
430/// gpu.yield %v : vector<32xf32>
431/// }
432/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
433struct WarpOpTransferWrite : public WarpDistributionPattern {
434 WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
435 unsigned maxNumElementsToExtract, PatternBenefit b = 1)
436 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
437 maxNumElementsToExtract(maxNumElementsToExtract) {}
438
439 /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
440 /// are multiples of the distribution ratio are supported at the moment.
441 LogicalResult tryDistributeOp(RewriterBase &rewriter,
442 vector::TransferWriteOp writeOp,
443 WarpExecuteOnLane0Op warpOp) const {
444 VectorType writtenVectorType = writeOp.getVectorType();
445
446 // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
447 // to separate it from the rest.
448 if (writtenVectorType.getRank() == 0)
449 return failure();
450
451 // 2. Compute the distributed type.
452 AffineMap map = distributionMapFn(writeOp.getVector());
453 VectorType targetType =
454 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
455 if (!targetType)
456 return failure();
457
458 // 2.5 Compute the distributed type for the new mask;
459 VectorType maskType;
460 if (writeOp.getMask()) {
461 // TODO: Distribution of masked writes with non-trivial permutation maps
462 // requires the distribution of the mask to elementwise match the
463 // distribution of the permuted written vector. Currently the details
464 // of which lane is responsible for which element is captured strictly
465 // by shape information on the warp op, and thus requires materializing
466 // the permutation in IR.
467 if (!writeOp.getPermutationMap().isMinorIdentity())
468 return failure();
469 maskType =
470 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
471 }
472
473 // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
474 // the rest.
475 vector::TransferWriteOp newWriteOp =
476 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
477
478 // 4. Reindex the write using the distribution map.
479 auto newWarpOp =
480 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
481
482 // Delinearize the lane id based on the way threads are divided across the
483 // vector. To get the number of threads per vector dimension, divide the
484 // sequential size by the distributed size along each dim.
485 rewriter.setInsertionPoint(newWriteOp);
486 SmallVector<OpFoldResult> delinearizedIdSizes;
487 for (auto [seqSize, distSize] :
488 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
489 assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
490 delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
491 }
492 SmallVector<Value> delinearized;
493 if (map.getNumResults() > 1) {
494 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
495 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
496 delinearizedIdSizes)
497 .getResults();
498 } else {
499 // If there is only one map result, we can elide the delinearization
500 // op and use the lane id directly.
501 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
502 }
503
504 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
505 Location loc = newWriteOp.getLoc();
506 SmallVector<Value> indices(newWriteOp.getIndices().begin(),
507 newWriteOp.getIndices().end());
508 for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
509 AffineExpr d0, d1;
510 bindDims(newWarpOp.getContext(), d0, d1);
511 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
512 if (!indexExpr)
513 continue;
514 unsigned indexPos = indexExpr.getPosition();
515 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
516 Value laneId = delinearized[vectorPos];
517 auto scale =
518 rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
520 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
521 }
522 newWriteOp.getIndicesMutable().assign(indices);
523
524 return success();
525 }
526
527 /// Extract TransferWriteOps of vector<1x> into a separate warp op.
528 LogicalResult tryExtractOp(RewriterBase &rewriter,
529 vector::TransferWriteOp writeOp,
530 WarpExecuteOnLane0Op warpOp) const {
531 Location loc = writeOp.getLoc();
532 VectorType vecType = writeOp.getVectorType();
533
534 if (vecType.getNumElements() > maxNumElementsToExtract) {
535 return rewriter.notifyMatchFailure(
536 warpOp,
537 llvm::formatv(
538 "writes more elements ({0}) than allowed to extract ({1})",
539 vecType.getNumElements(), maxNumElementsToExtract));
540 }
541
542 // Do not process warp ops that contain only TransferWriteOps.
543 if (llvm::all_of(warpOp.getOps(),
544 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
545 return failure();
546
547 SmallVector<Value> yieldValues = {writeOp.getVector()};
548 SmallVector<Type> retTypes = {vecType};
549 SmallVector<size_t> newRetIndices;
550 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
551 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
552 rewriter.setInsertionPointAfter(newWarpOp);
553
554 // Create a second warp op that contains only writeOp.
555 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
556 newWarpOp.getLaneid(),
557 newWarpOp.getWarpSize());
558 Block &body = secondWarpOp.getBodyRegion().front();
559 rewriter.setInsertionPointToStart(&body);
560 auto newWriteOp =
561 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
562 newWriteOp.getValueToStoreMutable().assign(
563 newWarpOp.getResult(newRetIndices[0]));
564 rewriter.eraseOp(writeOp);
565 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
566 return success();
567 }
568
569 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
570 PatternRewriter &rewriter) const override {
571 gpu::YieldOp yield = warpOp.getTerminator();
572 Operation *lastNode = yield->getPrevNode();
573 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
574 if (!writeOp)
575 return failure();
576
577 Value maybeMask = writeOp.getMask();
578 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
579 return writeOp.getVector() == value ||
580 (maybeMask && maybeMask == value) ||
581 warpOp.isDefinedOutsideOfRegion(value);
582 }))
583 return failure();
584
585 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
586 return success();
587
588 // Masked writes not supported for extraction.
589 if (writeOp.getMask())
590 return failure();
591
592 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
593 return success();
594
595 return failure();
596 }
597
598private:
599 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
600 /// execute op with the proper return type. The new write op is updated to
601 /// write the result of the new warp execute op. The old `writeOp` is deleted.
602 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
603 WarpExecuteOnLane0Op warpOp,
604 vector::TransferWriteOp writeOp,
605 VectorType targetType,
606 VectorType maybeMaskType) const {
607 assert(writeOp->getParentOp() == warpOp &&
608 "write must be nested immediately under warp");
609 OpBuilder::InsertionGuard g(rewriter);
610 SmallVector<size_t> newRetIndices;
611 WarpExecuteOnLane0Op newWarpOp;
612 if (maybeMaskType) {
614 rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
615 TypeRange{targetType, maybeMaskType}, newRetIndices);
616 } else {
618 rewriter, warpOp, ValueRange{{writeOp.getVector()}},
619 TypeRange{targetType}, newRetIndices);
620 }
621 rewriter.setInsertionPointAfter(newWarpOp);
622 auto newWriteOp =
623 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
624 rewriter.eraseOp(writeOp);
625 newWriteOp.getValueToStoreMutable().assign(
626 newWarpOp.getResult(newRetIndices[0]));
627 if (maybeMaskType)
628 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
629 return newWriteOp;
630 }
631
632 DistributionMapFn distributionMapFn;
633 unsigned maxNumElementsToExtract = 1;
634};
635
636/// Sink out elementwise op feeding into a warp op yield.
637/// ```
638/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
639/// ...
640/// %3 = arith.addf %1, %2 : vector<32xf32>
641/// gpu.yield %3 : vector<32xf32>
642/// }
643/// ```
644/// To
645/// ```
646/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
647/// vector<1xf32>, vector<1xf32>) {
648/// ...
649/// %4 = arith.addf %2, %3 : vector<32xf32>
650/// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
651/// vector<32xf32>
652/// }
653/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
654struct WarpOpElementwise : public WarpDistributionPattern {
655 using Base::Base;
656 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
657 PatternRewriter &rewriter) const override {
658 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
660 });
661 if (!yieldOperand)
662 return failure();
663
664 Operation *elementWise = yieldOperand->get().getDefiningOp();
665 unsigned operandIndex = yieldOperand->getOperandNumber();
666 Value distributedVal = warpOp.getResult(operandIndex);
667 SmallVector<Value> yieldValues;
668 SmallVector<Type> retTypes;
669 Location loc = warpOp.getLoc();
670 for (OpOperand &operand : elementWise->getOpOperands()) {
671 Type targetType;
672 if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
673 // If the result type is a vector, the operands must also be vectors.
674 auto operandType = cast<VectorType>(operand.get().getType());
675 targetType =
676 VectorType::get(vecType.getShape(), operandType.getElementType());
677 } else {
678 auto operandType = operand.get().getType();
679 assert(!isa<VectorType>(operandType) &&
680 "unexpected yield of vector from op with scalar result type");
681 targetType = operandType;
682 }
683 retTypes.push_back(targetType);
684 yieldValues.push_back(operand.get());
685 }
686 SmallVector<size_t> newRetIndices;
687 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
688 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
689 rewriter.setInsertionPointAfter(newWarpOp);
690 SmallVector<Value> newOperands(elementWise->getOperands().begin(),
691 elementWise->getOperands().end());
692 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
693 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
694 }
695 OpBuilder::InsertionGuard g(rewriter);
696 rewriter.setInsertionPointAfter(newWarpOp);
697 Operation *newOp = cloneOpWithOperandsAndTypes(
698 rewriter, loc, elementWise, newOperands,
699 {newWarpOp.getResult(operandIndex).getType()});
700 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
701 newOp->getResult(0));
702 return success();
703 }
704};
705
706/// Sink out splat constant op feeding into a warp op yield.
707/// ```
708/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
709/// ...
710/// %cst = arith.constant dense<2.0> : vector<32xf32>
711/// gpu.yield %cst : vector<32xf32>
712/// }
713/// ```
714/// To
715/// ```
716/// gpu.warp_execute_on_lane_0(%arg0 {
717/// ...
718/// }
719/// %0 = arith.constant dense<2.0> : vector<1xf32>
720struct WarpOpConstant : public WarpDistributionPattern {
721 using Base::Base;
722 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
723 PatternRewriter &rewriter) const override {
724 OpOperand *yieldOperand =
725 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
726 if (!yieldOperand)
727 return failure();
728 auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
729 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
730 if (!dense)
731 return failure();
732 // Notify the rewriter that the warp op is changing (see the comment on
733 // the WarpOpTransferRead pattern).
734 rewriter.startOpModification(warpOp);
735 unsigned operandIndex = yieldOperand->getOperandNumber();
736 Attribute scalarAttr = dense.getSplatValue<Attribute>();
737 auto newAttr = DenseElementsAttr::get(
738 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
739 Location loc = warpOp.getLoc();
740 rewriter.setInsertionPointAfter(warpOp);
741 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
742 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
743 rewriter.finalizeOpModification(warpOp);
744 return success();
745 }
746};
747
748/// Sink out step op feeding into a warp op yield.
749/// Vector step op is treated similar to arith.constant, apart from
750/// the result that represents a sequence [0, vec_size).
751/// Due to the to vec_size == warp_size limitation,
752/// we can simply wrap the lane id into a vector (i.e., broadcast).
753/// Supporting vec_size != warp_size may involve preserving the step
754/// result and using additional arith ops (the exact details are TBD).
755/// ```
756/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
757/// ...
758/// %cst = vector.step : vector<32xindex>
759/// gpu.yield %cst : vector<1xindex>
760/// }
761/// ```
762/// To
763/// ```
764/// gpu.warp_execute_on_lane_0(%arg0) {
765/// ...
766/// }
767/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
768struct WarpOpStep final : public WarpDistributionPattern {
769 using Base::Base;
770 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
771 PatternRewriter &rewriter) const override {
772 OpOperand *yieldOperand =
773 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
774 if (!yieldOperand)
775 return failure();
776 const unsigned operandIdx = yieldOperand->getOperandNumber();
777 auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
778 VectorType resTy = stepOp.getResult().getType();
779 if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
780 return rewriter.notifyMatchFailure(
781 warpOp,
782 llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
783 resTy.getNumElements(), warpOp.getWarpSize()));
784 VectorType newVecTy =
785 cast<VectorType>(warpOp.getResult(operandIdx).getType());
786 rewriter.setInsertionPointAfter(warpOp);
787 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
788 newVecTy, warpOp.getLaneid());
789 rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
790 return success();
791 }
792};
793
794/// Sink out transfer_read op feeding into a warp op yield.
795/// ```
796/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
797/// ...
/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
/// vector<32xf32>
800/// gpu.yield %2 : vector<32xf32>
801/// }
802/// ```
803/// To
804/// ```
805/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
806/// vector<1xf32>, vector<1xf32>) {
807/// ...
/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
/// vector<32xf32>
/// gpu.yield %2 : vector<32xf32>
810/// }
811/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
812struct WarpOpTransferRead : public WarpDistributionPattern {
813 using Base::Base;
814 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
815 PatternRewriter &rewriter) const override {
816 // Try to find a distributable yielded read. Note that this pattern can
817 // still fail at the end after distribution, in which case this might have
818 // missed another distributable read.
819 OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
820 // Don't duplicate transfer_read ops when distributing.
821 return isa<vector::TransferReadOp>(op) && op->hasOneUse();
822 });
823 if (!operand)
824 return rewriter.notifyMatchFailure(
825 warpOp, "warp result is not a vector.transfer_read op");
826 auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
827
828 // Source must be defined outside of the region.
829 if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
830 return rewriter.notifyMatchFailure(
831 read, "source must be defined outside of the region");
832
833 unsigned operandIndex = operand->getOperandNumber();
834 Value distributedVal = warpOp.getResult(operandIndex);
835
836 SmallVector<Value, 4> indices(read.getIndices().begin(),
837 read.getIndices().end());
838 auto sequentialType = cast<VectorType>(read.getResult().getType());
839 auto distributedType = cast<VectorType>(distributedVal.getType());
840 AffineMap map = calculateImplicitMap(sequentialType, distributedType);
841 AffineMap indexMap = map.compose(read.getPermutationMap());
842
843 // Try to delinearize the lane ID to match the rank expected for
844 // distribution.
845 SmallVector<Value> delinearizedIds;
846 if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
847 distributedType.getShape(), warpOp.getWarpSize(),
848 warpOp.getLaneid(), delinearizedIds)) {
849 return rewriter.notifyMatchFailure(
850 read, "cannot delinearize lane ID for distribution");
851 }
852 assert(!delinearizedIds.empty() || map.getNumResults() == 0);
853
854 // Distribute indices and the mask (if present).
855 OpBuilder::InsertionGuard g(rewriter);
856 SmallVector<Value> additionalResults(indices.begin(), indices.end());
857 SmallVector<Type> additionalResultTypes(indices.size(),
858 rewriter.getIndexType());
859 additionalResults.push_back(read.getPadding());
860 additionalResultTypes.push_back(read.getPadding().getType());
861
862 bool hasMask = false;
863 if (read.getMask()) {
864 hasMask = true;
865 // TODO: Distribution of masked reads with non-trivial permutation maps
866 // requires the distribution of the mask to elementwise match the
867 // distribution of the permuted written vector. Currently the details
868 // of which lane is responsible for which element is captured strictly
869 // by shape information on the warp op, and thus requires materializing
870 // the permutation in IR.
871 if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
872 return rewriter.notifyMatchFailure(
873 read, "non-trivial permutation maps not supported");
874 VectorType maskType =
875 getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
876 additionalResults.push_back(read.getMask());
877 additionalResultTypes.push_back(maskType);
878 }
879
880 SmallVector<size_t> newRetIndices;
881 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
882 rewriter, warpOp, additionalResults, additionalResultTypes,
883 newRetIndices);
884 distributedVal = newWarpOp.getResult(operandIndex);
885
886 // Distributed indices were appended first.
887 SmallVector<Value> newIndices;
888 for (int64_t i = 0, e = indices.size(); i < e; ++i)
889 newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
890
891 rewriter.setInsertionPointAfter(newWarpOp);
892 for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
893 AffineExpr d0, d1;
894 bindDims(read.getContext(), d0, d1);
895 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
896 if (!indexExpr)
897 continue;
898 unsigned indexPos = indexExpr.getPosition();
899 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
900 int64_t scale = distributedType.getDimSize(vectorPos);
901 newIndices[indexPos] = affine::makeComposedAffineApply(
902 rewriter, read.getLoc(), d0 + scale * d1,
903 {newIndices[indexPos], delinearizedIds[vectorPos]});
904 }
905
906 // Distributed padding value was appended right after the indices.
907 Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
908 // Distributed mask value was added at the end (if the op has a mask).
909 Value newMask =
910 hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
911 : Value();
912 auto newRead = vector::TransferReadOp::create(
913 rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
914 newIndices, read.getPermutationMapAttr(), newPadding, newMask,
915 read.getInBoundsAttr());
916
917 rewriter.replaceAllUsesWith(distributedVal, newRead);
918 return success();
919 }
920};
921
922/// Remove any result that has no use along with the matching yieldOp operand.
923// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Result types and yielded values that survive dead-result elimination
    // and deduplication.
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    // Maps a yielded value to the position of its first live occurrence, and
    // each live old result to the position of its deduplicated value.
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    gpu::YieldOp yield = warpOp.getTerminator();

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    // 1. recording the unique first position at which the value with uses is
    // yielded.
    // 2. recording for the result, the first position at which the dedup'ed
    // value is yielded.
    // 3. skipping from the new result types / new yielded values any result
    // that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      // `it.second` is false when the value was already recorded: this result
      // is a duplicate and contributes no new yield.
      if (!it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    // Dead results map to a null Value, which is fine since they have no uses.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};
984
985// If an operand is directly yielded out of the region we can forward it
986// directly and it doesn't need to go through the region.
987struct WarpOpForwardOperand : public WarpDistributionPattern {
988 using Base::Base;
989 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
990 PatternRewriter &rewriter) const override {
991 gpu::YieldOp yield = warpOp.getTerminator();
992 Value valForwarded;
993 unsigned resultIndex;
994 for (OpOperand &operand : yield->getOpOperands()) {
995 Value result = warpOp.getResult(operand.getOperandNumber());
996 if (result.use_empty())
997 continue;
998
999 // Assume all the values coming from above are uniform.
1000 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
1001 if (result.getType() != operand.get().getType())
1002 continue;
1003 valForwarded = operand.get();
1004 resultIndex = operand.getOperandNumber();
1005 break;
1006 }
1007 auto arg = dyn_cast<BlockArgument>(operand.get());
1008 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1009 continue;
1010 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1011 if (result.getType() != warpOperand.getType())
1012 continue;
1013 valForwarded = warpOperand;
1014 resultIndex = operand.getOperandNumber();
1015 break;
1016 }
1017 if (!valForwarded)
1018 return failure();
1019 // Notify the rewriter that the warp op is changing (see the comment on
1020 // the WarpOpTransferRead pattern).
1021 rewriter.startOpModification(warpOp);
1022 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1023 rewriter.finalizeOpModification(warpOp);
1024 return success();
1025 }
1026};
1027
1028struct WarpOpBroadcast : public WarpDistributionPattern {
1029 using Base::Base;
1030 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1031 PatternRewriter &rewriter) const override {
1032 OpOperand *operand =
1033 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1034 if (!operand)
1035 return failure();
1036 unsigned int operandNumber = operand->getOperandNumber();
1037 auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1038 Location loc = broadcastOp.getLoc();
1039 auto destVecType =
1040 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1041 Value broadcastSrc = broadcastOp.getSource();
1042 Type broadcastSrcType = broadcastSrc.getType();
1043
1044 // Check that the broadcast actually spans a set of values uniformly across
1045 // all threads. In other words, check that each thread can reconstruct
1046 // their own broadcast.
1047 // For that we simply check that the broadcast we want to build makes sense.
1048 if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1049 vector::BroadcastableToResult::Success)
1050 return failure();
1051 SmallVector<size_t> newRetIndices;
1052 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1053 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1054 rewriter.setInsertionPointAfter(newWarpOp);
1055 Value broadcasted = vector::BroadcastOp::create(
1056 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1057 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1058 broadcasted);
1059 return success();
1060 }
1061};
1062
1063/// Pattern to move shape cast out of the warp op. shape cast is basically a
1064/// no-op for warp distribution; we need to handle the shape though.
1065struct WarpOpShapeCast : public WarpDistributionPattern {
1066 using Base::Base;
1067 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1068 PatternRewriter &rewriter) const override {
1069 OpOperand *operand =
1070 getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1071 if (!operand)
1072 return failure();
1073
1074 auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1075
1076 unsigned int operandNumber = operand->getOperandNumber();
1077 auto castDistributedType =
1078 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1079 VectorType castOriginalType = oldCastOp.getSourceVectorType();
1080 VectorType castResultType = castDistributedType;
1081
1082 FailureOr<VectorType> maybeSrcType =
1083 inferDistributedSrcType(castDistributedType, castOriginalType);
1084 if (failed(maybeSrcType))
1085 return failure();
1086 castDistributedType = *maybeSrcType;
1087
1088 SmallVector<size_t> newRetIndices;
1089 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1090 rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1091 newRetIndices);
1092 rewriter.setInsertionPointAfter(newWarpOp);
1093 Value newCast = vector::ShapeCastOp::create(
1094 rewriter, oldCastOp.getLoc(), castResultType,
1095 newWarpOp->getResult(newRetIndices[0]));
1096 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1097 return success();
1098 }
1099
1100private:
1101 static FailureOr<VectorType>
1102 inferDistributedSrcType(VectorType distributedType, VectorType srcType) {
1103 unsigned distributedRank = distributedType.getRank();
1104 unsigned srcRank = srcType.getRank();
1105 if (distributedRank == srcRank)
1106 // Nothing to do.
1107 return distributedType;
1108 if (distributedRank < srcRank) {
1109 // If the distributed type has a smaller rank than the original type,
1110 // prepend with unit dimensions to make the types the same length.
1111 SmallVector<int64_t> shape(srcRank - distributedRank, 1);
1112 llvm::append_range(shape, distributedType.getShape());
1113 return VectorType::get(shape, distributedType.getElementType());
1114 }
1115 // Handle the expanding shape_cast's.
1116 //
1117 // If the casted-from type has one rank, we can assert that the element
1118 // count in that rank will match the full thread-level element count of
1119 // the yielded type.
1120 // Note that getNumElements() will correctly "flatten" the shape of the
1121 // specific shape_cast's distributed type (its distribution may be
1122 // different from the overall warp size, e.g. if the cast is applied to
1123 // a result of a gather).
1124 if (srcRank == 1)
1125 return VectorType::get(distributedType.getNumElements(),
1126 srcType.getElementType());
1127 // Try to strip leading unit dimensions to match the ranks. We bail out
1128 // for more complex tile sizes, because those would require us to
1129 // determine the specific distribution parameters to threads, which is
1130 // unfeasible within this pattern.
1131 unsigned excessDims = distributedRank - srcRank;
1132 ArrayRef<int64_t> shape = distributedType.getShape();
1133 if (!llvm::all_of(shape.take_front(excessDims),
1134 [](int64_t d) { return d == 1; }))
1135 return failure();
1136 return VectorType::get(shape.drop_front(excessDims),
1137 distributedType.getElementType());
1138 }
1139};
1140
1141/// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op
1142/// yield.
1143/// ```
1144/// %0 = ...
1145/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1146/// ...
1147/// %mask = vector.create_mask %0 : vector<32xi1>
1148/// // or %mask = vector.constant_mask[2] : vector<32xi1>
1149/// gpu.yield %mask : vector<32xi1>
1150/// }
1151/// ```
1152/// To
1153/// ```
1154/// %0 = ...
1155/// gpu.warp_execute_on_lane_0(%arg0) {
1156/// ...
1157/// }
1158/// %cmp = arith.cmpi ult, %laneid, %0
1159/// %ub = arith.select %cmp, %c0, %c1
1160/// %1 = vector.create_mask %ub : vector<1xi1>
1161template <typename OpType,
1162 typename = std::enable_if_t<llvm::is_one_of<
1163 OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
1164struct WarpOpCreateMask : public WarpDistributionPattern {
1165 using Base::Base;
1166 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1167 PatternRewriter &rewriter) const override {
1168 OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
1169 if (!yieldOperand)
1170 return failure();
1171
1172 Operation *mask = yieldOperand->get().getDefiningOp<OpType>();
1173
1174 // Early exit if any values needed for calculating the new mask indices
1175 // are defined inside the warp op.
1176 if (mask->getOperands().size() &&
1177 !llvm::all_of(mask->getOperands(), [&](Value value) {
1178 return warpOp.isDefinedOutsideOfRegion(value);
1179 }))
1180 return failure();
1181
1182 Location loc = mask->getLoc();
1183 unsigned operandIndex = yieldOperand->getOperandNumber();
1184
1185 auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1186 VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
1187 ArrayRef<int64_t> seqShape = seqType.getShape();
1188 ArrayRef<int64_t> distShape = distType.getShape();
1189 SmallVector<Value> materializedOperands;
1190 if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
1191 materializedOperands.append(mask->getOperands().begin(),
1192 mask->getOperands().end());
1193 } else {
1194 auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
1195 auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
1196 for (auto dimSize : dimSizes)
1197 materializedOperands.push_back(
1198 arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
1199 }
1200
1201 rewriter.setInsertionPointAfter(warpOp);
1202
1203 // Delinearize the lane ID for constructing the distributed mask sizes.
1204 SmallVector<Value> delinearizedIds;
1205 if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1206 warpOp.getWarpSize(), warpOp.getLaneid(),
1207 delinearizedIds))
1208 return rewriter.notifyMatchFailure(
1209 mask, "cannot delinearize lane ID for distribution");
1210 assert(!delinearizedIds.empty());
1211
1212 // Notify the rewriter that the warp op is changing (see the comment on
1213 // the WarpOpTransferRead pattern).
1214 rewriter.startOpModification(warpOp);
1215
1216 AffineExpr s0, s1;
1217 bindSymbols(rewriter.getContext(), s0, s1);
1218 SmallVector<Value> newOperands;
1219 for (int i = 0, e = distShape.size(); i < e; ++i) {
1220 // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1221 // find the distance from the largest mask index owned by this lane to the
1222 // original mask size. `vector.create_mask` implicitly clamps mask
1223 // operands to the range [0, mask_vector_size[i]], or in other words, the
1224 // mask sizes are always in the range [0, mask_vector_size[i]).
1225 Value maskDimIdx = affine::makeComposedAffineApply(
1226 rewriter, loc, s1 - s0 * distShape[i],
1227 {delinearizedIds[i], materializedOperands[i]});
1228 newOperands.push_back(maskDimIdx);
1229 }
1230
1231 auto newMask =
1232 vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
1233 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1234 rewriter.finalizeOpModification(warpOp);
1235 return success();
1236 }
1237};
1238
1239/// Sink out insert_strided_slice op feeding into a warp op yield.
1240/// ```
1241/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1242/// ...
1243/// %src = ... : vector<4x32xf32>
1244/// %dest = ... : vector<8x32xf32>
1245/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1246/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1247/// gpu.yield %insert : vector<8x32xf32>
1248/// }
1249/// ```
1250/// To
1251/// ```
1252/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1253/// vector<8x1xf32>) {
1254/// ...
1255/// %src = ... : vector<4x32xf32>
1256/// %dest = ... : vector<8x32xf32>
1257/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
1258/// }
1259/// %insert = vector.insert_strided_slice %0#0, %0#1,
1260/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1261/// ```
1262/// NOTE: Current support assumes that both src and dest vectors are distributed
1263/// to lanes and sinking the insert op does not require any cross lane
1264/// communication.
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");
    // Find the distributed dimension of the dest vector. There should be
    // exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");

    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // Currently we require that both source (kD) and dest (nD) vectors are
    // distributed. This requires that distributedDim (d) is contained in the
    // last k dims of the dest vector (d >= n - k).
    // TODO: Add support for case where source vector is not distributed.
    // The source's distributed dim is the dest's distributed dim shifted by
    // the rank difference (insert aligns src to the trailing dims of dest).
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp,
          "distributed dimension must be in the last k dims of dest vector");
    // Distributed dimension must be fully inserted.
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");
    // The distributed source keeps the sequential shape except along the
    // distributed dim, which shrinks to the per-lane size of the dest.
    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    VectorType newDestTy = distributedType;
    // Yield both the source and dest values out of the warp op with their
    // distributed types.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Create a new insert strided slice op that inserts distributed source into
    // distributed dest. Offsets/strides are reused unchanged: the distributed
    // dim is fully inserted, so its offset is necessarily 0.
    Value newInsert = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), distributedDest.getType(),
        distributedSource, distributedDest, insertOp.getOffsets(),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
    return success();
  }
};
1331
1332/// Sink out extract_strided_slice op feeding into a warp op yield.
1333/// ```
1334/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1335/// ...
1336/// %src = ... : vector<64x32xf32>
1337/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1338/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1339/// gpu.yield %extract : vector<16x32xf32>
1340/// }
1341/// ```
1342/// To
1343/// ```
1344/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1345/// ...
1346/// %src = ... : vector<64x32xf32>
1347/// gpu.yield %src : vector<64x32xf32>
1348/// }
1349/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1350/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1351/// ```
1352/// NOTE: Current support assumes that the extraction happens only on non
1353/// distributed dimensions (does not require cross lane communication).
1354struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
1355 using Base::Base;
1356 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1357 PatternRewriter &rewriter) const override {
1358 OpOperand *operand =
1359 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1360 if (!operand)
1361 return failure();
1362 unsigned int operandNumber = operand->getOperandNumber();
1363 auto extractOp =
1364 operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
1365 auto distributedType =
1366 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1367 // Distributed type must be 2D or higher.
1368 // TODO: Support 1D distributed types.
1369 if (distributedType.getRank() < 2)
1370 return rewriter.notifyMatchFailure(
1371 extractOp, "result vector type must be 2D or higher");
1372
1373 // Find the distributed dimension. There should be exactly one.
1374 auto yieldedType = cast<VectorType>(operand->get().getType());
1375 int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1376 assert(distributedDim != -1 && "could not find distributed dimension");
1377
1378 int64_t numOfExtractedDims =
1379 static_cast<int64_t>(extractOp.getSizes().size());
1380 // If the distributed dim is included in the extracted dims, then we make
1381 // sure distributed dim is fully extracted. If distributed dim is not
1382 // included in extracted dims, it is guaranteed to be fully extracted (i.e.
1383 // distributed dim comes after all the extracted dims)
1384 // TODO: Partial extraction from distributed dimension require cross lane
1385 // communication.
1386 if (distributedDim < numOfExtractedDims) {
1387 int64_t distributedDimOffset =
1388 llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1389 .getInt();
1390 int64_t distributedDimSize =
1391 llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1392 .getInt();
1393 if (distributedDimOffset != 0 ||
1394 distributedDimSize != yieldedType.getDimSize(distributedDim))
1395 return rewriter.notifyMatchFailure(
1396 extractOp, "distributed dimension must be fully extracted");
1397 }
1398 SmallVector<int64_t> newDistributedShape(
1399 extractOp.getSourceVectorType().getShape());
1400 newDistributedShape[distributedDim] =
1401 distributedType.getDimSize(distributedDim);
1402 auto newDistributedType =
1403 VectorType::get(newDistributedShape, distributedType.getElementType());
1404 SmallVector<size_t> newRetIndices;
1405 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1406 rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1407 newRetIndices);
1408 rewriter.setInsertionPointAfter(newWarpOp);
1409 SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1410 extractOp.getSizes(), [](Attribute attr) { return attr; });
1411 // Update the distributed sizes to match the distributed type.
1412 if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
1413 distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1414 distributedType.getDimSize(distributedDim));
1415
1416 // Create a new extract strided slice op that extracts from the
1417 // distributed vector.
1418 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1419 Value newExtract = vector::ExtractStridedSliceOp::create(
1420 rewriter, extractOp.getLoc(), distributedType, distributedVec,
1421 extractOp.getOffsets(),
1422 ArrayAttr::get(rewriter.getContext(), distributedSizes),
1423 extractOp.getStrides());
1424 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1425 newExtract);
1426 return success();
1427 }
1428};
1429
1430/// Pattern to move out vector.extract of single element vector. Those don't
1431/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getSource()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = vector::ExtractOp::create(
          rewriter, loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    // The result's distributed dims correspond to the source dims that
    // survive the extract (those after the extracted indices), so copy the
    // per-lane sizes into those trailing positions of the source shape.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
                                                 extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
1501
1502/// Pattern to move out vector.extract with a scalar result.
1503/// Only supports 1-D and 0-D sources for now.
1504struct WarpOpExtractScalar : public WarpDistributionPattern {
1505 WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1506 PatternBenefit b = 1)
1507 : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1508 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1509 PatternRewriter &rewriter) const override {
1510 OpOperand *operand =
1511 getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1512 if (!operand)
1513 return failure();
1514 unsigned int operandNumber = operand->getOperandNumber();
1515 auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1516 VectorType extractSrcType = extractOp.getSourceVectorType();
1517 // Only supports 1-D or 0-D sources for now.
1518 if (extractSrcType.getRank() > 1) {
1519 return rewriter.notifyMatchFailure(
1520 extractOp, "only 0-D or 1-D source supported for now");
1521 }
1522 // TODO: Supported shuffle types should be parameterizable, similar to
1523 // `WarpShuffleFromIdxFn`.
1524 if (!extractSrcType.getElementType().isF32() &&
1525 !extractSrcType.getElementType().isInteger(32))
1526 return rewriter.notifyMatchFailure(
1527 extractOp, "only f32/i32 element types are supported");
1528 bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1529 Type elType = extractSrcType.getElementType();
1530 VectorType distributedVecType;
1531 if (!is0dOrVec1Extract) {
1532 assert(extractSrcType.getRank() == 1 &&
1533 "expected that extract src rank is 0 or 1");
1534 if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1535 return failure();
1536 int64_t elementsPerLane =
1537 extractSrcType.getShape()[0] / warpOp.getWarpSize();
1538 distributedVecType = VectorType::get({elementsPerLane}, elType);
1539 } else {
1540 distributedVecType = extractSrcType;
1541 }
1542 // Yield source vector and position (if present) from warp op.
1543 SmallVector<Value> additionalResults{extractOp.getSource()};
1544 SmallVector<Type> additionalResultTypes{distributedVecType};
1545 additionalResults.append(
1546 SmallVector<Value>(extractOp.getDynamicPosition()));
1547 additionalResultTypes.append(
1548 SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1549
1550 Location loc = extractOp.getLoc();
1551 SmallVector<size_t> newRetIndices;
1552 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1553 rewriter, warpOp, additionalResults, additionalResultTypes,
1554 newRetIndices);
1555 rewriter.setInsertionPointAfter(newWarpOp);
1556 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1557
1558 // 0d extract: The new warp op broadcasts the source vector to all lanes.
1559 // All lanes extract the scalar.
1560 if (is0dOrVec1Extract) {
1561 Value newExtract;
1562 SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1563 newExtract =
1564 vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
1565 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1566 newExtract);
1567 return success();
1568 }
1569
1570 int64_t staticPos = extractOp.getStaticPosition()[0];
1571 OpFoldResult pos = ShapedType::isDynamic(staticPos)
1572 ? (newWarpOp->getResult(newRetIndices[1]))
1573 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1574 // 1d extract: Distribute the source vector. One lane extracts and shuffles
1575 // the value to all other lanes.
1576 int64_t elementsPerLane = distributedVecType.getShape()[0];
1577 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1578 // tid of extracting thread: pos / elementsPerLane
1579 Value broadcastFromTid = affine::makeComposedAffineApply(
1580 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1581 // Extract at position: pos % elementsPerLane
1582 Value newPos =
1583 elementsPerLane == 1
1584 ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
1585 : affine::makeComposedAffineApply(rewriter, loc,
1586 sym0 % elementsPerLane, pos);
1587 Value extracted =
1588 vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1589
1590 // Shuffle the extracted value to all lanes.
1591 Value shuffled = warpShuffleFromIdxFn(
1592 loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1593 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1594 return success();
1595 }
1596
1597private:
1598 WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1599};
1600
1601/// Pattern to move out vector.insert with a scalar input.
1602/// Only supports 1-D and 0-D destinations for now.
1603struct WarpOpInsertScalar : public WarpDistributionPattern {
1604 using Base::Base;
1605 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1606 PatternRewriter &rewriter) const override {
1607 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1608 if (!operand)
1609 return failure();
1610 unsigned int operandNumber = operand->getOperandNumber();
1611 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1612 VectorType vecType = insertOp.getDestVectorType();
1613 VectorType distrType =
1614 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1615
1616 // Only supports 1-D or 0-D destinations for now.
1617 if (vecType.getRank() > 1) {
1618 return rewriter.notifyMatchFailure(
1619 insertOp, "only 0-D or 1-D source supported for now");
1620 }
1621
1622 // Yield destination vector, source scalar and position from warp op.
1623 SmallVector<Value> additionalResults{insertOp.getDest(),
1624 insertOp.getValueToStore()};
1625 SmallVector<Type> additionalResultTypes{
1626 distrType, insertOp.getValueToStore().getType()};
1627 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1628 additionalResultTypes.append(
1629 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1630
1631 Location loc = insertOp.getLoc();
1632 SmallVector<size_t> newRetIndices;
1633 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1634 rewriter, warpOp, additionalResults, additionalResultTypes,
1635 newRetIndices);
1636 rewriter.setInsertionPointAfter(newWarpOp);
1637 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1638 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1639 rewriter.setInsertionPointAfter(newWarpOp);
1640
1641 OpFoldResult pos;
1642 if (vecType.getRank() != 0) {
1643 int64_t staticPos = insertOp.getStaticPosition()[0];
1644 pos = ShapedType::isDynamic(staticPos)
1645 ? (newWarpOp->getResult(newRetIndices[2]))
1646 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1647 }
1648
1649 // This condition is always true for 0-d vectors.
1650 if (vecType == distrType) {
1651 Value newInsert;
1652 SmallVector<OpFoldResult> indices;
1653 if (pos) {
1654 indices.push_back(pos);
1655 }
1656 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1657 distributedVec, indices);
1658 // Broadcast: Simply move the vector.insert op out.
1659 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1660 newInsert);
1661 return success();
1662 }
1663
1664 // This is a distribution. Only one lane should insert.
1665 int64_t elementsPerLane = distrType.getShape()[0];
1666 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1667 // tid of extracting thread: pos / elementsPerLane
1668 Value insertingLane = affine::makeComposedAffineApply(
1669 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1670 // Insert position: pos % elementsPerLane
1671 OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1672 rewriter, loc, sym0 % elementsPerLane, pos);
1673 Value isInsertingLane =
1674 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1675 newWarpOp.getLaneid(), insertingLane);
1676 Value newResult =
1677 scf::IfOp::create(
1678 rewriter, loc, isInsertingLane,
1679 /*thenBuilder=*/
1680 [&](OpBuilder &builder, Location loc) {
1681 Value newInsert = vector::InsertOp::create(
1682 builder, loc, newSource, distributedVec, newPos);
1683 scf::YieldOp::create(builder, loc, newInsert);
1684 },
1685 /*elseBuilder=*/
1686 [&](OpBuilder &builder, Location loc) {
1687 scf::YieldOp::create(builder, loc, distributedVec);
1688 })
1689 .getResult(0);
1690 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1691 return success();
1692 }
1693};
1694
/// Pattern to sink a vector.insert with a >= 2-D destination vector out of
/// the WarpOp. 0-D/1-D destinations are handled by WarpOpInsertScalar.
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                                 distributedDest,
                                                 insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    // The distributed dim is the one whose size differs between the yielded
    // (sequential) type and the warp result (distributed) type.
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                           distributedDest,
                                           insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      // All positions are static here (getAsIntegers asserts on dynamic
      // values), so the lane can be computed as a constant.
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = arith::ConstantIndexOp::create(
          rewriter, loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane =
          arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
                                                   distributedDest, newPos);
        scf::YieldOp::create(builder, loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        // Non-inserting lanes forward their chunk of the dest unchanged.
        scf::YieldOp::create(builder, loc, distributedDest);
      };
      newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
                                    /*thenBuilder=*/insertingBuilder,
                                    /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
1809
1810/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1811/// the scf.if is the last operation in the region so that it doesn't
1812/// change the order of execution. This creates a new scf.if after the
1813/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1814/// the "inner" WarpExecuteOnLane0Op. Example:
1815/// ```
1816/// gpu.warp_execute_on_lane_0(%laneid)[32] {
1817/// %payload = ... : vector<32xindex>
1818/// scf.if %pred {
1819/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1820/// vector<32xindex>
1821/// }
1822/// gpu.yield
1823/// }
1824/// ```
1825/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1826/// %payload = ... : vector<32xindex>
1827/// gpu.yield %payload : vector<32xindex>
1828/// }
1829/// scf.if %pred {
1830/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1831/// ^bb0(%arg1: vector<32xindex>):
1832/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1833/// }
1834/// }
1835/// ```
struct WarpOpScfIfOp : public WarpDistributionPattern {
  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // Only pick up `IfOp` if it is the last op in the region.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
    if (!ifOp)
      return failure();

    // The current `WarpOp` can yield two types of values:
    // 1. Not results of `IfOp`:
    //    Preserve them in the new `WarpOp`.
    //    Collect their yield index to remap the usages.
    // 2. Results of `IfOp`:
    //    They are not part of the new `WarpOp` results.
    //    Map current warp's yield operand index to `IfOp` result idx.
    SmallVector<Value> nonIfYieldValues;
    SmallVector<unsigned> nonIfYieldIndices;
    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
        nonIfYieldValues.push_back(yieldOperand.get());
        nonIfYieldIndices.push_back(yieldOperandIdx);
        continue;
      }
      OpResult ifResult = cast<OpResult>(yieldOperand.get());
      const unsigned ifResultIdx = ifResult.getResultNumber();
      ifResultMapping[yieldOperandIdx] = ifResultIdx;
      // If this `ifOp` result is vector type and it is yielded by the
      // `WarpOp`, we keep track the distributed type for this result.
      if (!isa<VectorType>(ifResult.getType()))
        continue;
      VectorType distType =
          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
      ifResultDistTypes[ifResultIdx] = distType;
    }

    // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
    // them
    auto [escapingValuesThen, escapingValueInputTypesThen,
          escapingValueDistTypesThen] =
        getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
                                     distributionMapFn);
    auto [escapingValuesElse, escapingValueInputTypesElse,
          escapingValueDistTypesElse] =
        getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
                                     distributionMapFn);
    // A default-constructed (null) Type signals that no distributed type
    // could be computed for an escaping value; bail out in that case.
    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
        llvm::is_contained(escapingValueDistTypesElse, Type{}))
      return failure();

    // The new `WarpOp` groups yields values in following order:
    // 1. Branch condition
    // 2. Escaping values then branch
    // 3. Escaping values else branch
    // 4. All non-`ifOp` yielded values.
    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
    newWarpOpYieldValues.append(escapingValuesThen.begin(),
                                escapingValuesThen.end());
    newWarpOpYieldValues.append(escapingValuesElse.begin(),
                                escapingValuesElse.end());
    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
                              escapingValueDistTypesThen.end());
    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                              escapingValueDistTypesElse.end());

    for (auto [idx, val] :
         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
      newWarpOpYieldValues.push_back(val);
      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
    }
    // Replace the old `WarpOp` with the new one that has additional yield
    // values and types.
    SmallVector<size_t> newIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
    // `ifOp` returns the result of the inner warp op.
    SmallVector<Type> newIfOpDistResTypes;
    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
      Type distType = cast<Value>(res).getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(cast<Value>(res));
        // Fallback to affine map if the dist result was not previously recorded
        distType = ifResultDistTypes.count(i)
                       ? ifResultDistTypes[i]
                       : getDistributedType(
                             vecType, map,
                             map.isEmpty() ? 1 : newWarpOp.getWarpSize());
      }
      newIfOpDistResTypes.push_back(distType);
    }
    // Create a new `IfOp` outside the new `WarpOp` region.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newIfOp = scf::IfOp::create(
        rewriter, ifOp.getLoc(), newIfOpDistResTypes,
        newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
        static_cast<bool>(ifOp.elseBlock()));
    // Wraps one branch of the original `ifOp` in an inner `WarpOp` placed in
    // the corresponding branch of `newIfOp`. `warpResRangeStart` is the index
    // into `newIndices` of the first escaping value yielded for this branch
    // (1 for then-escapes, 1 + #then-escapes for else-escapes; index 0 is the
    // branch condition).
    auto encloseRegionInWarpOp =
        [&](Block *oldIfBranch, Block *newIfBranch,
            llvm::SmallSetVector<Value, 32> &escapingValues,
            SmallVector<Type> &escapingValueInputTypes,
            size_t warpResRangeStart) {
          OpBuilder::InsertionGuard g(rewriter);
          if (!newIfBranch)
            return;
          rewriter.setInsertionPointToStart(newIfBranch);
          llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
          SmallVector<Value> innerWarpInputVals;
          SmallVector<Type> innerWarpInputTypes;
          // Forward each escaping value (now a distributed result of
          // `newWarpOp`) as an input of the inner `WarpOp`, remembering which
          // block argument it will become.
          for (size_t i = 0; i < escapingValues.size();
               ++i, ++warpResRangeStart) {
            innerWarpInputVals.push_back(
                newWarpOp.getResult(newIndices[warpResRangeStart]));
            escapeValToBlockArgIndex[escapingValues[i]] =
                innerWarpInputTypes.size();
            innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
          }
          auto innerWarp = WarpExecuteOnLane0Op::create(
              rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
              newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
              innerWarpInputVals, innerWarpInputTypes);

          // Move the old branch body into the inner `WarpOp` and append the
          // block arguments the forwarded escaping values map to.
          innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
          innerWarp.getWarpRegion().addArguments(
              innerWarpInputTypes,
              SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));

          SmallVector<Value> yieldOperands;
          for (Value operand : oldIfBranch->getTerminator()->getOperands())
            yieldOperands.push_back(operand);
          rewriter.eraseOp(oldIfBranch->getTerminator());

          rewriter.setInsertionPointToEnd(innerWarp.getBody());
          gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
          rewriter.setInsertionPointAfter(innerWarp);
          scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());

          // Update any users of escaping values that were forwarded to the
          // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
          innerWarp.walk([&](Operation *op) {
            SmallVector<std::pair<unsigned, Value>> replacements;
            for (OpOperand &operand : op->getOpOperands()) {
              auto it = escapeValToBlockArgIndex.find(operand.get());
              if (it == escapeValToBlockArgIndex.end())
                continue;
              replacements.emplace_back(
                  operand.getOperandNumber(),
                  innerWarp.getBodyRegion().getArgument(it->second));
            }
            if (!replacements.empty()) {
              rewriter.modifyOpInPlace(op, [&]() {
                for (auto [idx, newVal] : replacements)
                  op->setOperand(idx, newVal);
              });
            }
          });
          mlir::vector::moveScalarUniformCode(innerWarp);
        };
    encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
                          &newIfOp.getThenRegion().front(), escapingValuesThen,
                          escapingValueInputTypesThen, 1);
    if (!ifOp.getElseRegion().empty())
      encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
                            &newIfOp.getElseRegion().front(),
                            escapingValuesElse, escapingValueInputTypesElse,
                            1 + escapingValuesThen.size());
    // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
    // result.
    for (auto [origIdx, newIdx] : ifResultMapping)
      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                    newIfOp.getResult(newIdx), newIfOp);

    // The original `ifOp` was left inside `newWarpOp` with empty then/else
    // regions (their blocks were moved into the inner WarpOps by takeBody).
    // Clear remaining uses and erase it to restore IR validity. Directly
    // update newWarpOp's yield operands instead of using replaceAllUsesWith,
    // to avoid triggering notifyOperandReplaced on the now-invalid ifOp.
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(ifOp);
      Operation *yield = newWarpOp.getTerminator();
      rewriter.modifyOpInPlace(yield, [&]() {
        for (auto [origIdx, ifResultIdx] : ifResultMapping) {
          // Poison placeholders keep the yield well-typed; real users were
          // already redirected to `newIfOp` above.
          Value poison = ub::PoisonOp::create(
              rewriter, ifOp.getLoc(), ifOp.getResult(ifResultIdx).getType());
          yield->setOperand(origIdx, poison);
        }
      });
      rewriter.eraseOp(ifOp);
    }

    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};
2040
2041/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
2042/// the scf.ForOp is the last operation in the region so that it doesn't
2043/// change the order of execution. This creates a new scf.for region after the
2044/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
2045/// WarpExecuteOnLane0Op region. Example:
2046/// ```
2047/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
2048/// ...
2049/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
2050/// -> (vector<128xf32>) {
2051/// ...
2052/// scf.yield %r : vector<128xf32>
2053/// }
2054/// gpu.yield %v1 : vector<128xf32>
2055/// }
2056/// ```
/// To:
/// ```
2058/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
2059/// ...
2060/// gpu.yield %v : vector<128xf32>
2061/// }
2062/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
2063/// -> (vector<4xf32>) {
2064/// %iw = gpu.warp_execute_on_lane_0(%laneid)
2065/// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
2066/// ^bb0(%arg: vector<128xf32>):
2067/// ...
2068/// gpu.yield %ir : vector<128xf32>
2069/// }
2070/// scf.yield %iw : vector<4xf32>
2071/// }
2072/// ```
2073struct WarpOpScfForOp : public WarpDistributionPattern {
2074
2075 WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
2076 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
2077 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2078 PatternRewriter &rewriter) const override {
2079 gpu::YieldOp warpOpYield = warpOp.getTerminator();
2080 // Only pick up `ForOp` if it is the last op in the region.
2081 Operation *lastNode = warpOpYield->getPrevNode();
2082 auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
2083 if (!forOp)
2084 return failure();
2085 // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
2086 // Those Values need to be returned by the new warp op.
2087 auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2088 getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2089 distributionMapFn);
2090 if (llvm::is_contained(escapingValueDistTypes, Type{}))
2091 return failure();
2092 // `WarpOp` can yield two types of values:
2093 // 1. Values that are not results of the `ForOp`:
2094 // These values must also be yielded by the new `WarpOp`. Also, we need
2095 // to record the index mapping for these values to replace them later.
2096 // 2. Values that are results of the `ForOp`:
2097 // In this case, we record the index mapping between the `WarpOp` result
2098 // index and matching `ForOp` result index.
2099 // Additionally, we keep track of the distributed types for all `ForOp`
2100 // vector results.
2101 SmallVector<Value> nonForYieldedValues;
2102 SmallVector<unsigned> nonForResultIndices;
2103 llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2104 llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2105 llvm::SmallBitVector forResultsMapped(forOp.getNumResults());
2106 for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2107 // Yielded value is not a result of the forOp.
2108 if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
2109 nonForYieldedValues.push_back(yieldOperand.get());
2110 nonForResultIndices.push_back(yieldOperand.getOperandNumber());
2111 continue;
2112 }
2113 OpResult forResult = cast<OpResult>(yieldOperand.get());
2114 unsigned int forResultNumber = forResult.getResultNumber();
2115 forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
2116 forResultsMapped.set(forResultNumber);
2117 // If this `ForOp` result is vector type and it is yielded by the
2118 // `WarpOp`, we keep track the distributed type for this result.
2119 if (!isa<VectorType>(forResult.getType()))
2120 continue;
2121 VectorType distType = cast<VectorType>(
2122 warpOp.getResult(yieldOperand.getOperandNumber()).getType());
2123 forResultDistTypes[forResultNumber] = distType;
2124 }
2125
2126 // Newly created `WarpOp` will yield values in following order:
2127 // 1. Loop bounds.
2128 // 2. All init args of the `ForOp`.
2129 // 3. All escaping values.
2130 // 4. All non-`ForOp` yielded values.
2131 SmallVector<Value> newWarpOpYieldValues;
2132 SmallVector<Type> newWarpOpDistTypes;
2133 newWarpOpYieldValues.insert(
2134 newWarpOpYieldValues.end(),
2135 {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
2136 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2137 {forOp.getLowerBound().getType(),
2138 forOp.getUpperBound().getType(),
2139 forOp.getStep().getType()});
2140 for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2141 newWarpOpYieldValues.push_back(initArg);
2142 // Compute the distributed type for this init arg.
2143 Type distType = initArg.getType();
2144 if (auto vecType = dyn_cast<VectorType>(distType)) {
2145 // If the `ForOp` result corresponds to this init arg is already yielded
2146 // we can get the distributed type from `forResultDistTypes` map.
2147 // Otherwise, we compute it using distributionMapFn.
2148 AffineMap map = distributionMapFn(initArg);
2149 distType =
2150 forResultDistTypes.count(i)
2151 ? forResultDistTypes[i]
2152 : getDistributedType(vecType, map,
2153 map.isEmpty() ? 1 : warpOp.getWarpSize());
2154 }
2155 newWarpOpDistTypes.push_back(distType);
2156 }
2157 // Insert escaping values and their distributed types.
2158 newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2159 escapingValues.begin(), escapingValues.end());
2160 newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2161 escapingValueDistTypes.begin(),
2162 escapingValueDistTypes.end());
2163 // Next, we insert all non-`ForOp` yielded values and their distributed
2164 // types.
2165 for (auto [i, v] :
2166 llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2167 newWarpOpYieldValues.push_back(v);
2168 newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2169 }
2170 // Create the new `WarpOp` with the updated yield values and types.
2171 SmallVector<size_t> newIndices;
2172 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2173 rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
2174
2175 // Next, we create a new `ForOp` with the init args yielded by the new
2176 // `WarpOp`.
2177 const unsigned initArgsStartIdx = 3; // After loop bounds.
2178 const unsigned escapingValuesStartIdx =
2179 initArgsStartIdx +
2180 forOp.getInitArgs().size(); // `ForOp` init args are positioned before
2181 // escaping values in the new `WarpOp`.
2182 SmallVector<Value> newForOpOperands;
2183 for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
2184 newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
2185
2186 // Create a new `ForOp` outside the new `WarpOp` region.
2187 OpBuilder::InsertionGuard g(rewriter);
2188 rewriter.setInsertionPointAfter(newWarpOp);
2189 auto newForOp = scf::ForOp::create(
2190 rewriter, forOp.getLoc(),
2191 /**LowerBound=**/ newWarpOp.getResult(newIndices[0]),
2192 /**UpperBound=**/ newWarpOp.getResult(newIndices[1]),
2193 /**Step=**/ newWarpOp.getResult(newIndices[2]), newForOpOperands,
2194 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
2195 // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
2196 // newly created `ForOp`. This `WarpOp` will contain all ops that were
2197 // contained within the original `ForOp` body.
2198 rewriter.setInsertionPointToStart(newForOp.getBody());
2199
2200 SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2201 newForOp.getRegionIterArgs().end());
2202 SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2203 forOp.getResultTypes().end());
2204 // Escaping values are forwarded to the inner `WarpOp` as its (additional)
2205 // arguments. We keep track of the mapping between these values and their
2206 // argument index in the inner `WarpOp` (to replace users later).
2207 llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2208 for (size_t i = escapingValuesStartIdx;
2209 i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2210 innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
2211 argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2212 innerWarpInputType.size();
2213 innerWarpInputType.push_back(
2214 escapingValueInputTypes[i - escapingValuesStartIdx]);
2215 }
2216 // Create the inner `WarpOp` with the new input values and types.
2217 auto innerWarp = WarpExecuteOnLane0Op::create(
2218 rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2219 newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2220 innerWarpInputType);
2221
2222 // Inline the `ForOp` body into the inner `WarpOp` body.
2223 SmallVector<Value> argMapping;
2224 argMapping.push_back(newForOp.getInductionVar());
2225 for (Value args : innerWarp.getBody()->getArguments())
2226 argMapping.push_back(args);
2227
2228 argMapping.resize(forOp.getBody()->getNumArguments());
2229 SmallVector<Value> yieldOperands;
2230 for (Value operand : forOp.getBody()->getTerminator()->getOperands()) {
2231 if (BlockArgument blockArg = dyn_cast<BlockArgument>(operand);
2232 blockArg && blockArg.getOwner() == forOp.getBody()) {
2233 yieldOperands.push_back(argMapping[blockArg.getArgNumber()]);
2234 continue;
2235 }
2236 yieldOperands.push_back(operand);
2237 }
2238
2239 rewriter.eraseOp(forOp.getBody()->getTerminator());
2240 rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2241
2242 // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
2243 // original `ForOp` results.
2244 rewriter.setInsertionPointToEnd(innerWarp.getBody());
2245 gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2246 rewriter.setInsertionPointAfter(innerWarp);
2247 // Insert a scf.yield op at the end of the new `ForOp` body that yields
2248 // the inner `WarpOp` results.
2249 if (!innerWarp.getResults().empty())
2250 scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2251
2252 // Update the users of the new `WarpOp` results that were coming from the
2253 // original `ForOp` to the corresponding new `ForOp` result.
2254 for (auto [origIdx, newIdx] : forResultMapping)
2255 rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
2256 newForOp.getResult(newIdx), newForOp);
2257
2258 // The original `ForOp` was left inside `newWarpOp` with an empty body
2259 // region (its body block was moved into `innerWarp` by `mergeBlocks`).
2260 // Clear remaining uses and erase it to restore IR validity.
2261 for (OpResult result : forOp.getResults()) {
2262 if (forResultsMapped.test(result.getResultNumber()))
2263 rewriter.replaceAllUsesWith(
2264 result, forOp.getInitArgs()[result.getResultNumber()]);
2265 }
2266 rewriter.eraseOp(forOp);
2267
2268 // Update any users of escaping values that were forwarded to the
2269 // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
2270 newForOp.walk([&](Operation *op) {
2271 SmallVector<std::pair<unsigned, Value>> replacements;
2272 for (OpOperand &operand : op->getOpOperands()) {
2273 auto it = argIndexMapping.find(operand.get());
2274 if (it == argIndexMapping.end())
2275 continue;
2276 replacements.emplace_back(
2277 operand.getOperandNumber(),
2278 innerWarp.getBodyRegion().getArgument(it->second));
2279 }
2280 if (!replacements.empty()) {
2281 rewriter.modifyOpInPlace(op, [&]() {
2282 for (auto [idx, newVal] : replacements)
2283 op->setOperand(idx, newVal);
2284 });
2285 }
2286 });
2287
2288 // Finally, hoist out any now uniform code from the inner `WarpOp`.
2289 mlir::vector::moveScalarUniformCode(innerWarp);
2290 return success();
2291 }
2292
2293private:
2294 DistributionMapFn distributionMapFn;
2295};
2296
2297/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
2298/// The vector is reduced in parallel. Currently limited to vector size
2299/// matching the warpOp size. E.g.:
2300/// ```
2301/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
2302/// %0 = "some_def"() : () -> (vector<32xf32>)
2303/// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
2304/// gpu.yield %1 : f32
2305/// }
2306/// ```
2307/// is lowered to:
2308/// ```
2309/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
2310/// %1 = "some_def"() : () -> (vector<32xf32>)
2311/// gpu.yield %1 : vector<32xf32>
2312/// }
2313/// %a = vector.extract %0[0] : f32 from vector<1xf32>
2314/// %r = ("warp.reduction %a")
2315/// ```
2316struct WarpOpReduction : public WarpDistributionPattern {
2317 WarpOpReduction(MLIRContext *context,
2318 DistributedReductionFn distributedReductionFn,
2319 PatternBenefit benefit = 1)
2320 : WarpDistributionPattern(context, benefit),
2321 distributedReductionFn(std::move(distributedReductionFn)) {}
2322
2323 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2324 PatternRewriter &rewriter) const override {
2325 OpOperand *yieldOperand =
2326 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2327 if (!yieldOperand)
2328 return failure();
2329
2330 auto reductionOp =
2331 cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
2332 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2333 // Only rank 1 vectors supported.
2334 if (vectorType.getRank() != 1)
2335 return rewriter.notifyMatchFailure(
2336 warpOp, "Only rank 1 reductions can be distributed.");
2337 // Only warp_size-sized vectors supported.
2338 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2339 return rewriter.notifyMatchFailure(
2340 warpOp, "Reduction vector dimension must match was size.");
2341 if (!reductionOp.getType().isIntOrFloat())
2342 return rewriter.notifyMatchFailure(
2343 warpOp, "Reduction distribution currently only supports floats and "
2344 "integer types.");
2345
2346 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2347 // Return vector that will be reduced from the WarpExecuteOnLane0Op.
2348 unsigned operandIndex = yieldOperand->getOperandNumber();
2349 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2350 SmallVector<Type> retTypes = {
2351 VectorType::get({numElements}, reductionOp.getType())};
2352 if (reductionOp.getAcc()) {
2353 yieldValues.push_back(reductionOp.getAcc());
2354 retTypes.push_back(reductionOp.getAcc().getType());
2355 }
2356 SmallVector<size_t> newRetIndices;
2357 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2358 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2359 rewriter.setInsertionPointAfter(newWarpOp);
2360
2361 // Obtain data to reduce for a single lane.
2362 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2363 // Distribute and reduce across threads.
2364 Value fullReduce =
2365 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2366 reductionOp.getKind(), newWarpOp.getWarpSize());
2367 if (reductionOp.getAcc()) {
2368 fullReduce = vector::makeArithReduction(
2369 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2370 newWarpOp.getResult(newRetIndices[1]));
2371 }
2372 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
2373 return success();
2374 }
2375
2376private:
2377 DistributedReductionFn distributedReductionFn;
2378};
2379
2380} // namespace
2381
2383 RewritePatternSet &patterns,
2385 patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
2386}
2387
2388void mlir::vector::populateDistributeTransferWriteOpPatterns(
2389 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2390 unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2391 patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2392 maxNumElementsToExtract, benefit);
2393}
2394
2395void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2396 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2397 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2398 PatternBenefit readBenefit) {
2399 patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2400 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2401 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
2402 WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
2403 WarpOpCreateMask<vector::CreateMaskOp>,
2404 WarpOpCreateMask<vector::ConstantMaskOp>,
2405 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2406 patterns.getContext(), benefit);
2407 patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2408 benefit);
2409 patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2410 benefit);
2411 patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2412 benefit);
2413}
2414
2415void mlir::vector::populateDistributeReduction(
2416 RewritePatternSet &patterns,
2417 const DistributedReductionFn &distributedReductionFn,
2418 PatternBenefit benefit) {
2419 patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2420 benefit);
2421}
2422
2423/// Helper to know if an op can be hoisted out of the region.
2424static bool canBeHoisted(Operation *op,
2425 function_ref<bool(Value)> definedOutside) {
2426 return llvm::all_of(op->getOperands(), definedOutside) &&
2427 isMemoryEffectFree(op) && op->getNumRegions() == 0;
2428}
2429
2430void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2431 Block *body = warpOp.getBody();
2432
2433 // Keep track of the ops we want to hoist.
2434 llvm::SmallSetVector<Operation *, 8> opsToMove;
2435
2436 // Helper to check if a value is or will be defined outside of the region.
2437 auto isDefinedOutsideOfBody = [&](Value value) {
2438 auto *definingOp = value.getDefiningOp();
2439 return (definingOp && opsToMove.count(definingOp)) ||
2440 warpOp.isDefinedOutsideOfRegion(value);
2441 };
2442
2443 // Do not use walk here, as we do not want to go into nested regions and hoist
2444 // operations from there.
2445 for (auto &op : body->without_terminator()) {
2446 bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2447 return isa<VectorType>(result.getType());
2448 });
2449 if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2450 opsToMove.insert(&op);
2451 }
2452
2453 // Move all the ops marked as uniform outside of the region.
2454 for (Operation *op : opsToMove)
2455 op->moveBefore(warpOp);
2456}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:376
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setOperand(unsigned idx, Value value)
Definition Operation.h:377
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:875
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
unsigned getNumOperands()
Definition Operation.h:372
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:441
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which statifies the filter lamdba condition and is not dead.