MLIR 23.0.0git
VectorDistribute.cpp
Go to the documentation of this file.
1//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "mlir/IR/AffineExpr.h"
18#include "mlir/IR/Attributes.h"
22#include "llvm/ADT/SetVector.h"
23#include "llvm/ADT/SmallVectorExtras.h"
24#include "llvm/Support/FormatVariadic.h"
25#include <utility>
26
27using namespace mlir;
28using namespace mlir::vector;
29using namespace mlir::gpu;
30
31/// Currently the distribution map is implicit based on the vector shape. In the
32/// future it will be part of the op.
33/// Example:
34/// ```
35/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
36/// ...
37/// gpu.yield %3 : vector<32x16x64xf32>
38/// }
39/// ```
40/// Would have an implicit map of:
41/// `(d0, d1, d2) -> (d0, d2)`
42static AffineMap calculateImplicitMap(VectorType sequentialType,
43 VectorType distributedType) {
45 perm.reserve(1);
46 // Check which dimensions of the sequential type are different than the
47 // dimensions of the distributed type to know the distributed dimensions. Then
48 // associate each distributed dimension to an ID in order.
49 for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50 if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
51 perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
52 }
53 auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
54 distributedType.getContext());
55 return map;
56}
57
58/// Given a sequential and distributed vector type, returns the distributed
59/// dimension. This function expects that only a single dimension is
60/// distributed.
61static int getDistributedDim(VectorType sequentialType,
62 VectorType distributedType) {
63 assert(sequentialType.getRank() == distributedType.getRank() &&
64 "sequential and distributed vector types must have the same rank");
65 int64_t distributedDim = -1;
66 for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
67 if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
68 // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
69 // support distributing multiple dimensions in the future.
70 assert(distributedDim == -1 && "found multiple distributed dims");
71 distributedDim = i;
72 }
73 }
74 return distributedDim;
75}
76
77namespace {
78
79/// Helper struct to create the load / store operations that permit transit
80/// through the parallel / sequential and the sequential / parallel boundaries
81/// when performing `rewriteWarpOpToScfFor`.
82///
83/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    // The distribution map only exists for the vector<->vector case; scalars
    // transit through a memref<1xT> and need no map.
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  /// Return the buffer offset of lane `laneId` along distributed dimension
  /// `index`, i.e. `laneId * distributedVectorType.getDimSize(index)`.
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_thread_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  /// 1. scalars of type T transit through a memref<1xT>.
  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return memref::StoreOp::create(b, loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      // Each lane writes its own slice at its distributed offset; the
      // sequential value is written in full starting at offset 0.
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferWriteOp::create(
        b, loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_thread_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  /// 1. scalars of type T transit through a memref<1xT>.
  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// When broadcastMode is true, the load is not distributed to account for
  /// the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  ///   %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///     gpu.yield %cst : f32
  ///   }
  ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return memref::LoadOp::create(b, loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.read of memref.read depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      // Each lane reads back only its own slice; a sequential-typed load
      // reads the full buffer from offset 0 (broadcast).
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return vector::TransferReadOp::create(
        b, loc, cast<VectorType>(type), buffer, indices,
        /*padding=*/std::nullopt,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  // Preregistered values on either side of the sequential/parallel boundary.
  Value sequentialVal, distributedVal, laneId, zero;
  // Null when the corresponding value is a scalar.
  VectorType sequentialVectorType, distributedVectorType;
  // Implicit map from sequential to distributed dims; only set when both
  // registered values are vectors.
  AffineMap distributionMap;
};
185
186} // namespace
187
188// Clones `op` into a new operation that takes `operands` and returns
189// `resultTypes`.
191 Location loc, Operation *op,
192 ArrayRef<Value> operands,
193 ArrayRef<Type> resultTypes) {
194 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
195 op->getAttrs());
196 return rewriter.create(res);
197}
198
199namespace {
200
201/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
202/// thread `laneId` executes the entirety of the computation.
203///
204/// After the transformation:
205/// - the IR within the scf.if op can be thought of as executing sequentially
206/// (from the point of view of threads along `laneId`).
207/// - the IR outside of the scf.if op can be thought of as executing in
208/// parallel (from the point of view of threads along `laneId`).
209///
210/// Values that need to transit through the parallel / sequential and the
211/// sequential / parallel boundaries do so via reads and writes to a temporary
212/// memory location.
213///
214/// The transformation proceeds in multiple steps:
215/// 1. Create the scf.if op.
216/// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
217/// within the scf.if to transit the values captured from above.
218/// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
219/// consistent within the scf.if.
220/// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
221/// 5. Insert appropriate writes within scf.if and reads after the scf.if to
222/// transit the values returned by the op.
223/// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
224/// consistent after the scf.if.
225/// 7. Perform late cleanups.
226///
227/// All this assumes the vector distribution occurs along the most minor
228/// distributed vector dimension.
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op predicated on `laneId == 0` so only lane 0
    // executes the sequential body.
    Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
    Value isLane0 = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
                                  /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSynchronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp, replacing the block arguments with
    // the loads built in step 2.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSynchronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    scf::YieldOp::create(rewriter, yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  // User-provided hooks for buffer allocation and warp-level synchronization.
  const WarpExecuteOnLane0LoweringOptions &options;
};
340
341/// Return the distributed vector type based on the original type and the
342/// distribution map. The map is expected to have a dimension equal to the
343/// original type rank and should be a projection where the results are the
344/// distributed dimensions. If the number of results is zero there is no
345/// distribution (i.e. original type is returned).
346/// Otherwise, The number of results should be equal to the number
347/// of warp sizes which is currently limited to 1.
348/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
349/// and a warp size of 16 would distribute the second dimension (associated to
350/// d1) and return vector<16x2x64>
351static VectorType getDistributedType(VectorType originalType, AffineMap map,
352 int64_t warpSize) {
353 // If the map has zero results, return the original type.
354 if (map.getNumResults() == 0)
355 return originalType;
356 SmallVector<int64_t> targetShape(originalType.getShape());
357 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
358 unsigned position = map.getDimPosition(i);
359 if (targetShape[position] % warpSize != 0) {
360 if (warpSize % targetShape[position] != 0) {
361 return VectorType();
362 }
363 warpSize /= targetShape[position];
364 targetShape[position] = 1;
365 continue;
366 }
367 targetShape[position] = targetShape[position] / warpSize;
368 warpSize = 1;
369 break;
370 }
371 if (warpSize != 1) {
372 return VectorType();
373 }
374 VectorType targetType =
375 VectorType::get(targetShape, originalType.getElementType());
376 return targetType;
377}
378
379/// Given a warpOp that contains ops with regions, the corresponding op's
380/// "inner" region and the distributionMapFn, get all values used by the op's
381/// region that are defined within the warpOp, but outside the inner region.
382/// Return the set of values, their types and their distributed types.
383std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
385getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
386 DistributionMapFn distributionMapFn) {
387 llvm::SmallSetVector<Value, 32> escapingValues;
388 SmallVector<Type> escapingValueTypes;
389 SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
390 if (innerRegion.empty())
391 return {std::move(escapingValues), std::move(escapingValueTypes),
392 std::move(escapingValueDistTypes)};
393 mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
394 Operation *parent = operand->get().getParentRegion()->getParentOp();
395 if (warpOp->isAncestor(parent)) {
396 if (!escapingValues.insert(operand->get()))
397 return;
398 Type distType = operand->get().getType();
399 if (auto vecType = dyn_cast<VectorType>(distType)) {
400 AffineMap map = distributionMapFn(operand->get());
401 distType = getDistributedType(vecType, map,
402 map.isEmpty() ? 1 : warpOp.getWarpSize());
403 }
404 escapingValueTypes.push_back(operand->get().getType());
405 escapingValueDistTypes.push_back(distType);
406 }
407 });
408 return {std::move(escapingValues), std::move(escapingValueTypes),
409 std::move(escapingValueDistTypes)};
410}
411
412/// Distribute transfer_write ops based on the affine map returned by
413/// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
414/// will not be distributed (it should be less than the warp size).
415///
416/// Example:
417/// ```
418/// %0 = gpu.warp_execute_on_lane_0(%id){
419/// ...
420/// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
421/// gpu.yield
422/// }
423/// ```
424/// To
425/// ```
426/// %r:3 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
427/// ...
428/// gpu.yield %v : vector<32xf32>
429/// }
430/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
431struct WarpOpTransferWrite : public WarpDistributionPattern {
432 WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
433 unsigned maxNumElementsToExtract, PatternBenefit b = 1)
434 : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
435 maxNumElementsToExtract(maxNumElementsToExtract) {}
436
437 /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
438 /// are multiples of the distribution ratio are supported at the moment.
439 LogicalResult tryDistributeOp(RewriterBase &rewriter,
440 vector::TransferWriteOp writeOp,
441 WarpExecuteOnLane0Op warpOp) const {
442 VectorType writtenVectorType = writeOp.getVectorType();
443
444 // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
445 // to separate it from the rest.
446 if (writtenVectorType.getRank() == 0)
447 return failure();
448
449 // 2. Compute the distributed type.
450 AffineMap map = distributionMapFn(writeOp.getVector());
451 VectorType targetType =
452 getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
453 if (!targetType)
454 return failure();
455
456 // 2.5 Compute the distributed type for the new mask;
457 VectorType maskType;
458 if (writeOp.getMask()) {
459 // TODO: Distribution of masked writes with non-trivial permutation maps
460 // requires the distribution of the mask to elementwise match the
461 // distribution of the permuted written vector. Currently the details
462 // of which lane is responsible for which element is captured strictly
463 // by shape information on the warp op, and thus requires materializing
464 // the permutation in IR.
465 if (!writeOp.getPermutationMap().isMinorIdentity())
466 return failure();
467 maskType =
468 getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
469 }
470
471 // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
472 // the rest.
473 vector::TransferWriteOp newWriteOp =
474 cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
475
476 // 4. Reindex the write using the distribution map.
477 auto newWarpOp =
478 newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
479
480 // Delinearize the lane id based on the way threads are divided across the
481 // vector. To get the number of threads per vector dimension, divide the
482 // sequential size by the distributed size along each dim.
483 rewriter.setInsertionPoint(newWriteOp);
484 SmallVector<OpFoldResult> delinearizedIdSizes;
485 for (auto [seqSize, distSize] :
486 llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
487 assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
488 delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
489 }
490 SmallVector<Value> delinearized;
491 if (map.getNumResults() > 1) {
492 delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
493 rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
494 delinearizedIdSizes)
495 .getResults();
496 } else {
497 // If there is only one map result, we can elide the delinearization
498 // op and use the lane id directly.
499 delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
500 }
501
502 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
503 Location loc = newWriteOp.getLoc();
504 SmallVector<Value> indices(newWriteOp.getIndices().begin(),
505 newWriteOp.getIndices().end());
506 for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
507 AffineExpr d0, d1;
508 bindDims(newWarpOp.getContext(), d0, d1);
509 auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
510 if (!indexExpr)
511 continue;
512 unsigned indexPos = indexExpr.getPosition();
513 unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
514 Value laneId = delinearized[vectorPos];
515 auto scale =
516 rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
518 rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
519 }
520 newWriteOp.getIndicesMutable().assign(indices);
521
522 return success();
523 }
524
525 /// Extract TransferWriteOps of vector<1x> into a separate warp op.
526 LogicalResult tryExtractOp(RewriterBase &rewriter,
527 vector::TransferWriteOp writeOp,
528 WarpExecuteOnLane0Op warpOp) const {
529 Location loc = writeOp.getLoc();
530 VectorType vecType = writeOp.getVectorType();
531
532 if (vecType.getNumElements() > maxNumElementsToExtract) {
533 return rewriter.notifyMatchFailure(
534 warpOp,
535 llvm::formatv(
536 "writes more elements ({0}) than allowed to extract ({1})",
537 vecType.getNumElements(), maxNumElementsToExtract));
538 }
539
540 // Do not process warp ops that contain only TransferWriteOps.
541 if (llvm::all_of(warpOp.getOps(),
542 llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
543 return failure();
544
545 SmallVector<Value> yieldValues = {writeOp.getVector()};
546 SmallVector<Type> retTypes = {vecType};
547 SmallVector<size_t> newRetIndices;
548 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
549 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
550 rewriter.setInsertionPointAfter(newWarpOp);
551
552 // Create a second warp op that contains only writeOp.
553 auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
554 newWarpOp.getLaneid(),
555 newWarpOp.getWarpSize());
556 Block &body = secondWarpOp.getBodyRegion().front();
557 rewriter.setInsertionPointToStart(&body);
558 auto newWriteOp =
559 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
560 newWriteOp.getValueToStoreMutable().assign(
561 newWarpOp.getResult(newRetIndices[0]));
562 rewriter.eraseOp(writeOp);
563 gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
564 return success();
565 }
566
567 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
568 PatternRewriter &rewriter) const override {
569 gpu::YieldOp yield = warpOp.getTerminator();
570 Operation *lastNode = yield->getPrevNode();
571 auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
572 if (!writeOp)
573 return failure();
574
575 Value maybeMask = writeOp.getMask();
576 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
577 return writeOp.getVector() == value ||
578 (maybeMask && maybeMask == value) ||
579 warpOp.isDefinedOutsideOfRegion(value);
580 }))
581 return failure();
582
583 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
584 return success();
585
586 // Masked writes not supported for extraction.
587 if (writeOp.getMask())
588 return failure();
589
590 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
591 return success();
592
593 return failure();
594 }
595
596private:
597 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
598 /// execute op with the proper return type. The new write op is updated to
599 /// write the result of the new warp execute op. The old `writeOp` is deleted.
600 vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
601 WarpExecuteOnLane0Op warpOp,
602 vector::TransferWriteOp writeOp,
603 VectorType targetType,
604 VectorType maybeMaskType) const {
605 assert(writeOp->getParentOp() == warpOp &&
606 "write must be nested immediately under warp");
607 OpBuilder::InsertionGuard g(rewriter);
608 SmallVector<size_t> newRetIndices;
609 WarpExecuteOnLane0Op newWarpOp;
610 if (maybeMaskType) {
612 rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
613 TypeRange{targetType, maybeMaskType}, newRetIndices);
614 } else {
616 rewriter, warpOp, ValueRange{{writeOp.getVector()}},
617 TypeRange{targetType}, newRetIndices);
618 }
619 rewriter.setInsertionPointAfter(newWarpOp);
620 auto newWriteOp =
621 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
622 rewriter.eraseOp(writeOp);
623 newWriteOp.getValueToStoreMutable().assign(
624 newWarpOp.getResult(newRetIndices[0]));
625 if (maybeMaskType)
626 newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
627 return newWriteOp;
628 }
629
630 DistributionMapFn distributionMapFn;
631 unsigned maxNumElementsToExtract = 1;
632};
633
634/// Sink out elementwise op feeding into a warp op yield.
635/// ```
636/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
637/// ...
638/// %3 = arith.addf %1, %2 : vector<32xf32>
639/// gpu.yield %3 : vector<32xf32>
640/// }
641/// ```
642/// To
643/// ```
644/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
645/// vector<1xf32>, vector<1xf32>) {
646/// ...
647/// %4 = arith.addf %2, %3 : vector<32xf32>
648/// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
649/// vector<32xf32>
650/// }
651/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
652struct WarpOpElementwise : public WarpDistributionPattern {
653 using Base::Base;
654 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
655 PatternRewriter &rewriter) const override {
656 OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
658 });
659 if (!yieldOperand)
660 return failure();
661
662 Operation *elementWise = yieldOperand->get().getDefiningOp();
663 unsigned operandIndex = yieldOperand->getOperandNumber();
664 Value distributedVal = warpOp.getResult(operandIndex);
665 SmallVector<Value> yieldValues;
666 SmallVector<Type> retTypes;
667 Location loc = warpOp.getLoc();
668 for (OpOperand &operand : elementWise->getOpOperands()) {
669 Type targetType;
670 if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
671 // If the result type is a vector, the operands must also be vectors.
672 auto operandType = cast<VectorType>(operand.get().getType());
673 targetType =
674 VectorType::get(vecType.getShape(), operandType.getElementType());
675 } else {
676 auto operandType = operand.get().getType();
677 assert(!isa<VectorType>(operandType) &&
678 "unexpected yield of vector from op with scalar result type");
679 targetType = operandType;
680 }
681 retTypes.push_back(targetType);
682 yieldValues.push_back(operand.get());
683 }
684 SmallVector<size_t> newRetIndices;
685 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
686 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
687 rewriter.setInsertionPointAfter(newWarpOp);
688 SmallVector<Value> newOperands(elementWise->getOperands().begin(),
689 elementWise->getOperands().end());
690 for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
691 newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
692 }
693 OpBuilder::InsertionGuard g(rewriter);
694 rewriter.setInsertionPointAfter(newWarpOp);
695 Operation *newOp = cloneOpWithOperandsAndTypes(
696 rewriter, loc, elementWise, newOperands,
697 {newWarpOp.getResult(operandIndex).getType()});
698 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
699 newOp->getResult(0));
700 return success();
701 }
702};
703
704/// Sink out splat constant op feeding into a warp op yield.
705/// ```
706/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
707/// ...
708/// %cst = arith.constant dense<2.0> : vector<32xf32>
709/// gpu.yield %cst : vector<32xf32>
710/// }
711/// ```
712/// To
713/// ```
714/// gpu.warp_execute_on_lane_0(%arg0 {
715/// ...
716/// }
717/// %0 = arith.constant dense<2.0> : vector<1xf32>
718struct WarpOpConstant : public WarpDistributionPattern {
719 using Base::Base;
720 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
721 PatternRewriter &rewriter) const override {
722 OpOperand *yieldOperand =
723 getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
724 if (!yieldOperand)
725 return failure();
726 auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
727 auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
728 if (!dense)
729 return failure();
730 // Notify the rewriter that the warp op is changing (see the comment on
731 // the WarpOpTransferRead pattern).
732 rewriter.startOpModification(warpOp);
733 unsigned operandIndex = yieldOperand->getOperandNumber();
734 Attribute scalarAttr = dense.getSplatValue<Attribute>();
735 auto newAttr = DenseElementsAttr::get(
736 cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
737 Location loc = warpOp.getLoc();
738 rewriter.setInsertionPointAfter(warpOp);
739 Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
740 rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
741 rewriter.finalizeOpModification(warpOp);
742 return success();
743 }
744};
745
746/// Sink out step op feeding into a warp op yield.
747/// Vector step op is treated similar to arith.constant, apart from
748/// the result that represents a sequence [0, vec_size).
/// Due to the vec_size == warp_size limitation,
750/// we can simply wrap the lane id into a vector (i.e., broadcast).
751/// Supporting vec_size != warp_size may involve preserving the step
752/// result and using additional arith ops (the exact details are TBD).
753/// ```
754/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
755/// ...
756/// %cst = vector.step : vector<32xindex>
757/// gpu.yield %cst : vector<1xindex>
758/// }
759/// ```
760/// To
761/// ```
762/// gpu.warp_execute_on_lane_0(%arg0) {
763/// ...
764/// }
765/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
766struct WarpOpStep final : public WarpDistributionPattern {
767 using Base::Base;
768 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
769 PatternRewriter &rewriter) const override {
770 OpOperand *yieldOperand =
771 getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
772 if (!yieldOperand)
773 return failure();
774 const unsigned operandIdx = yieldOperand->getOperandNumber();
775 auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
776 VectorType resTy = stepOp.getResult().getType();
777 if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
778 return rewriter.notifyMatchFailure(
779 warpOp,
780 llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
781 resTy.getNumElements(), warpOp.getWarpSize()));
782 VectorType newVecTy =
783 cast<VectorType>(warpOp.getResult(operandIdx).getType());
784 rewriter.setInsertionPointAfter(warpOp);
785 Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
786 newVecTy, warpOp.getLaneid());
787 rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
788 return success();
789 }
790};
791
792/// Sink out transfer_read op feeding into a warp op yield.
793/// ```
794/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
795/// ...
/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
/// vector<32xf32>
798/// gpu.yield %2 : vector<32xf32>
799/// }
800/// ```
801/// To
802/// ```
803/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
804/// vector<1xf32>, vector<1xf32>) {
805/// ...
/// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
/// vector<32xf32>
/// gpu.yield %2 : vector<32xf32>
808/// }
809/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    // Implicit map from sequential dims to distributed dims; composed with
    // the read's permutation map, it tells which read index each distributed
    // dimension contributes to.
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    // The guard restores the original insertion point on exit.
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element is captured strictly
      // by shape information on the warp op, and thus requires materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    // Yield indices, padding, and (optionally) mask out of the region so they
    // are available for the sunk read.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    // Offset each index that maps to a distributed dimension:
    // new_index = old_index + lane_id_component * per_lane_dim_size.
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = vector::TransferReadOp::create(
        rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
        newIndices, read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};
919
920/// Remove any result that has no use along with the matching yieldOp operand.
921// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    // Maps a yielded value to the first position it is yielded at, and each
    // live result to the position of its dedup'ed yielded value.
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    gpu::YieldOp yield = warpOp.getTerminator();

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    // 1. recording the unique first position at which the value with uses is
    // yielded.
    // 2. recording for the result, the first position at which the dedup'ed
    // value is yielded.
    // 3. skipping from the new result types / new yielded values any result
    // that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      // Insertion failed: this value was already yielded at an earlier
      // position; the result has been mapped to that position above.
      if (!it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        // Dead result: no uses to rewire, a null placeholder suffices.
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};
982
983// If an operand is directly yielded out of the region we can forward it
984// directly and it doesn't need to go through the region.
985struct WarpOpForwardOperand : public WarpDistributionPattern {
986 using Base::Base;
987 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
988 PatternRewriter &rewriter) const override {
989 gpu::YieldOp yield = warpOp.getTerminator();
990 Value valForwarded;
991 unsigned resultIndex;
992 for (OpOperand &operand : yield->getOpOperands()) {
993 Value result = warpOp.getResult(operand.getOperandNumber());
994 if (result.use_empty())
995 continue;
996
997 // Assume all the values coming from above are uniform.
998 if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
999 if (result.getType() != operand.get().getType())
1000 continue;
1001 valForwarded = operand.get();
1002 resultIndex = operand.getOperandNumber();
1003 break;
1004 }
1005 auto arg = dyn_cast<BlockArgument>(operand.get());
1006 if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1007 continue;
1008 Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1009 if (result.getType() != warpOperand.getType())
1010 continue;
1011 valForwarded = warpOperand;
1012 resultIndex = operand.getOperandNumber();
1013 break;
1014 }
1015 if (!valForwarded)
1016 return failure();
1017 // Notify the rewriter that the warp op is changing (see the comment on
1018 // the WarpOpTransferRead pattern).
1019 rewriter.startOpModification(warpOp);
1020 rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1021 rewriter.finalizeOpModification(warpOp);
1022 return success();
1023 }
1024};
1025
1026struct WarpOpBroadcast : public WarpDistributionPattern {
1027 using Base::Base;
1028 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1029 PatternRewriter &rewriter) const override {
1030 OpOperand *operand =
1031 getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1032 if (!operand)
1033 return failure();
1034 unsigned int operandNumber = operand->getOperandNumber();
1035 auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1036 Location loc = broadcastOp.getLoc();
1037 auto destVecType =
1038 cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1039 Value broadcastSrc = broadcastOp.getSource();
1040 Type broadcastSrcType = broadcastSrc.getType();
1041
1042 // Check that the broadcast actually spans a set of values uniformly across
1043 // all threads. In other words, check that each thread can reconstruct
1044 // their own broadcast.
1045 // For that we simply check that the broadcast we want to build makes sense.
1046 if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1047 vector::BroadcastableToResult::Success)
1048 return failure();
1049 SmallVector<size_t> newRetIndices;
1050 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1051 rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1052 rewriter.setInsertionPointAfter(newWarpOp);
1053 Value broadcasted = vector::BroadcastOp::create(
1054 rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1055 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1056 broadcasted);
1057 return success();
1058 }
1059};
1060
1061/// Pattern to move shape cast out of the warp op. shape cast is basically a
1062/// no-op for warp distribution; we need to handle the shape though.
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    // Distributed type of the shape cast result, as chosen by the warp op.
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // Infer the distributed type of the shape cast *source* from the
    // distributed result type; bail out if it cannot be determined.
    FailureOr<VectorType> maybeSrcType =
        inferDistributedSrcType(castDistributedType, castOriginalType);
    if (failed(maybeSrcType))
      return failure();
    castDistributedType = *maybeSrcType;

    // Yield the cast source out of the region and re-create the shape cast
    // outside, now operating on distributed types.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = vector::ShapeCastOp::create(
        rewriter, oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }

private:
  /// Infer the distributed type of the shape cast source, given the
  /// distributed type of its result (`distributedType`) and its sequential
  /// source type (`srcType`). Returns failure when the source distribution
  /// cannot be derived from shape information alone.
  static FailureOr<VectorType>
  inferDistributedSrcType(VectorType distributedType, VectorType srcType) {
    unsigned distributedRank = distributedType.getRank();
    unsigned srcRank = srcType.getRank();
    if (distributedRank == srcRank)
      // Nothing to do.
      return distributedType;
    if (distributedRank < srcRank) {
      // If the distributed type has a smaller rank than the original type,
      // prepend with unit dimensions to make the types the same length.
      SmallVector<int64_t> shape(srcRank - distributedRank, 1);
      llvm::append_range(shape, distributedType.getShape());
      return VectorType::get(shape, distributedType.getElementType());
    }
    // Handle the expanding shape_cast's.
    //
    // If the casted-from type has one rank, we can assert that the element
    // count in that rank will match the full thread-level element count of
    // the yielded type.
    // Note that getNumElements() will correctly "flatten" the shape of the
    // specific shape_cast's distributed type (its distribution may be
    // different from the overall warp size, e.g. if the cast is applied to
    // a result of a gather).
    if (srcRank == 1)
      return VectorType::get(distributedType.getNumElements(),
                             srcType.getElementType());
    // Try to strip leading unit dimensions to match the ranks. We bail out
    // for more complex tile sizes, because those would require us to
    // determine the specific distribution parameters to threads, which is
    // unfeasible within this pattern.
    unsigned excessDims = distributedRank - srcRank;
    ArrayRef<int64_t> shape = distributedType.getShape();
    if (!llvm::all_of(shape.take_front(excessDims),
                      [](int64_t d) { return d == 1; }))
      return failure();
    return VectorType::get(shape.drop_front(excessDims),
                           distributedType.getElementType());
  }
};
1138
1139/// Sink out vector.create_mask / vector.constant_mask op feeding into a warp op
1140/// yield.
1141/// ```
1142/// %0 = ...
1143/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1144/// ...
1145/// %mask = vector.create_mask %0 : vector<32xi1>
1146/// // or %mask = vector.constant_mask[2] : vector<32xi1>
1147/// gpu.yield %mask : vector<32xi1>
1148/// }
1149/// ```
1150/// To
1151/// ```
1152/// %0 = ...
1153/// gpu.warp_execute_on_lane_0(%arg0) {
1154/// ...
1155/// }
1156/// %cmp = arith.cmpi ult, %laneid, %0
1157/// %ub = arith.select %cmp, %c0, %c1
1158/// %1 = vector.create_mask %ub : vector<1xi1>
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, (llvm::IsaPred<OpType>));
    if (!yieldOperand)
      return failure();

    Operation *mask = yieldOperand->get().getDefiningOp<OpType>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (mask->getOperands().size() &&
        !llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask->getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = cast<VectorType>(mask->getResult(0).getType());
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();
    // Normalize both op kinds to a list of mask-size operands: create_mask
    // already carries them as operands; constant_mask's static sizes are
    // materialized as arith.constant ops.
    SmallVector<Value> materializedOperands;
    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
      materializedOperands.append(mask->getOperands().begin(),
                                  mask->getOperands().end());
    } else {
      auto constantMaskOp = cast<vector::ConstantMaskOp>(mask);
      auto dimSizes = constantMaskOp.getMaskDimSizesAttr().asArrayRef();
      for (auto dimSize : dimSizes)
        materializedOperands.push_back(
            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
    }

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to the
      // original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], or in other words, the
      // mask sizes are always in the range [0, mask_vector_size[i]).
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], materializedOperands[i]});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
1236
1237/// Sink out insert_strided_slice op feeding into a warp op yield.
1238/// ```
1239/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1240/// ...
1241/// %src = ... : vector<4x32xf32>
1242/// %dest = ... : vector<8x32xf32>
1243/// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1244/// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1245/// gpu.yield %insert : vector<8x32xf32>
1246/// }
1247/// ```
1248/// To
1249/// ```
1250/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1251/// vector<8x1xf32>) {
1252/// ...
1253/// %src = ... : vector<4x32xf32>
1254/// %dest = ... : vector<8x32xf32>
1255/// gpu.yield %src, %dest : vector<4x16xf32>, vector<8x16xf32>
1256/// }
1257/// %insert = vector.insert_strided_slice %0#0, %0#1,
1258/// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1259/// ```
1260/// NOTE: Current support assumes that both src and dest vectors are distributed
1261/// to lanes and sinking the insert op does not require any cross lane
1262/// communication.
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");
    // Find the distributed dimension of the dest vector. There should be
    // exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");

    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // Currently we require that both source (kD) and dest (nD) vectors are
    // distributed. This requires that distributedDim (d) is contained in the
    // last k dims of the dest vector (d >= n - k).
    // TODO: Add support for case where source vector is not distributed.
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp,
          "distributed dimension must be in the last k dims of dest vector");
    // Distributed dimension must be fully inserted.
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");
    // Distributed source shape: same as the sequential source shape, except
    // the distributed dim is shrunk to its per-lane size.
    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    VectorType newDestTy = distributedType;
    // Yield the source and dest vectors out of the region with their
    // distributed types.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Create a new insert strided slice op that inserts distributed source into
    // distributed dest.
    Value newInsert = vector::InsertStridedSliceOp::create(
        rewriter, insertOp.getLoc(), distributedDest.getType(),
        distributedSource, distributedDest, insertOp.getOffsets(),
        insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
    return success();
  }
};
1329
1330/// Sink out extract_strided_slice op feeding into a warp op yield.
1331/// ```
1332/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1333/// ...
1334/// %src = ... : vector<64x32xf32>
1335/// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1336/// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1337/// gpu.yield %extract : vector<16x32xf32>
1338/// }
1339/// ```
1340/// To
1341/// ```
1342/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1343/// ...
1344/// %src = ... : vector<64x32xf32>
1345/// gpu.yield %src : vector<64x32xf32>
1346/// }
1347/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1348/// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1349/// ```
1350/// NOTE: Current support assumes that the extraction happens only on non
1351/// distributed dimensions (does not require cross lane communication).
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp =
        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          extractOp, "result vector type must be 2D or higher");

    // Find the distributed dimension. There should be exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");

    int64_t numOfExtractedDims =
        static_cast<int64_t>(extractOp.getSizes().size());
    // If the distributed dim is included in the extracted dims, then we make
    // sure distributed dim is fully extracted. If distributed dim is not
    // included in extracted dims, it is guaranteed to be fully extracted (i.e.
    // distributed dim comes after all the extracted dims)
    // TODO: Partial extraction from distributed dimension require cross lane
    // communication.
    if (distributedDim < numOfExtractedDims) {
      int64_t distributedDimOffset =
          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
              .getInt();
      int64_t distributedDimSize =
          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
              .getInt();
      if (distributedDimOffset != 0 ||
          distributedDimSize != yieldedType.getDimSize(distributedDim))
        return rewriter.notifyMatchFailure(
            extractOp, "distributed dimension must be fully extracted");
    }
    // Distributed source shape: same as the sequential source shape, except
    // the distributed dim uses its per-lane size.
    SmallVector<int64_t> newDistributedShape(
        extractOp.getSourceVectorType().getShape());
    newDistributedShape[distributedDim] =
        distributedType.getDimSize(distributedDim);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    // Yield the extract source out of the region with its distributed type.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    // Update the distributed sizes to match the distributed type.
    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));

    // Create a new extract strided slice op that extracts from the
    // distributed vector.
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), distributedType, distributedVec,
        extractOp.getOffsets(),
        ArrayAttr::get(rewriter.getContext(), distributedSizes),
        extractOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
1427
1428/// Pattern to move out vector.extract of single element vector. Those don't
1429/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getSource()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = vector::ExtractOp::create(
          rewriter, loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    // The extracted (leading) dims keep their sequential sizes; the trailing
    // dims take the sizes of the distributed result type.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
                                                 extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
1499
1500/// Pattern to move out vector.extract with a scalar result.
1501/// Only supports 1-D and 0-D sources for now.
struct WarpOpExtractScalar : public WarpDistributionPattern {
  // `warpShuffleFromIdxFn` abstracts the target-specific warp shuffle used to
  // broadcast the extracted scalar from the owning lane to all lanes.
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp op result that is yielded by a vector.extract op.
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only supports 1-D or 0-D sources for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // TODO: Supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    // A 0-D source and a 1-element 1-D source are handled the same way: the
    // whole (single-element) source is broadcast to every lane.
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      // The source must distribute evenly across the warp.
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield source vector and position (if present) from warp op.
    SmallVector<Value> additionalResults{extractOp.getSource()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0d extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Recover the extract position: either a constant folded into an
    // attribute, or the dynamic value routed through the new warp op above.
    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    // 1d extract: Distribute the source vector. One lane extracts and shuffles
    // the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    // NOTE(review): `ceilDiv` is used although the comment above describes
    // floor division; the two only agree when elementsPerLane == 1 or pos is
    // a multiple of elementsPerLane. Confirm intent for other cases.
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};
1598
1599/// Pattern to move out vector.insert with a scalar input.
1600/// Only supports 1-D and 0-D destinations for now.
1601struct WarpOpInsertScalar : public WarpDistributionPattern {
1602 using Base::Base;
1603 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1604 PatternRewriter &rewriter) const override {
1605 OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1606 if (!operand)
1607 return failure();
1608 unsigned int operandNumber = operand->getOperandNumber();
1609 auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1610 VectorType vecType = insertOp.getDestVectorType();
1611 VectorType distrType =
1612 cast<VectorType>(warpOp.getResult(operandNumber).getType());
1613
1614 // Only supports 1-D or 0-D destinations for now.
1615 if (vecType.getRank() > 1) {
1616 return rewriter.notifyMatchFailure(
1617 insertOp, "only 0-D or 1-D source supported for now");
1618 }
1619
1620 // Yield destination vector, source scalar and position from warp op.
1621 SmallVector<Value> additionalResults{insertOp.getDest(),
1622 insertOp.getValueToStore()};
1623 SmallVector<Type> additionalResultTypes{
1624 distrType, insertOp.getValueToStore().getType()};
1625 additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1626 additionalResultTypes.append(
1627 SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1628
1629 Location loc = insertOp.getLoc();
1630 SmallVector<size_t> newRetIndices;
1631 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1632 rewriter, warpOp, additionalResults, additionalResultTypes,
1633 newRetIndices);
1634 rewriter.setInsertionPointAfter(newWarpOp);
1635 Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1636 Value newSource = newWarpOp->getResult(newRetIndices[1]);
1637 rewriter.setInsertionPointAfter(newWarpOp);
1638
1639 OpFoldResult pos;
1640 if (vecType.getRank() != 0) {
1641 int64_t staticPos = insertOp.getStaticPosition()[0];
1642 pos = ShapedType::isDynamic(staticPos)
1643 ? (newWarpOp->getResult(newRetIndices[2]))
1644 : OpFoldResult(rewriter.getIndexAttr(staticPos));
1645 }
1646
1647 // This condition is always true for 0-d vectors.
1648 if (vecType == distrType) {
1649 Value newInsert;
1650 SmallVector<OpFoldResult> indices;
1651 if (pos) {
1652 indices.push_back(pos);
1653 }
1654 newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1655 distributedVec, indices);
1656 // Broadcast: Simply move the vector.insert op out.
1657 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1658 newInsert);
1659 return success();
1660 }
1661
1662 // This is a distribution. Only one lane should insert.
1663 int64_t elementsPerLane = distrType.getShape()[0];
1664 AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1665 // tid of extracting thread: pos / elementsPerLane
1666 Value insertingLane = affine::makeComposedAffineApply(
1667 rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1668 // Insert position: pos % elementsPerLane
1669 OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1670 rewriter, loc, sym0 % elementsPerLane, pos);
1671 Value isInsertingLane =
1672 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1673 newWarpOp.getLaneid(), insertingLane);
1674 Value newResult =
1675 scf::IfOp::create(
1676 rewriter, loc, isInsertingLane,
1677 /*thenBuilder=*/
1678 [&](OpBuilder &builder, Location loc) {
1679 Value newInsert = vector::InsertOp::create(
1680 builder, loc, newSource, distributedVec, newPos);
1681 scf::YieldOp::create(builder, loc, newInsert);
1682 },
1683 /*elseBuilder=*/
1684 [&](OpBuilder &builder, Location loc) {
1685 scf::YieldOp::create(builder, loc, distributedVec);
1686 })
1687 .getResult(0);
1688 rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1689 return success();
1690 }
1691};
1692
/// Pattern to move out vector.insert with a vector source and a 2-D or
/// higher-dimensional destination. 0-D/1-D destinations are handled by
/// WarpOpInsertScalar instead.
struct WarpOpInsert : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a warp op result that is yielded by a vector.insert op.
    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
    if (insertOp.getDestVectorType().getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
          {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                                 distributedDest,
                                                 insertOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    // NOTE(review): this inline scan duplicates the getDistributedDim helper
    // used elsewhere in this file; consider reusing it.
    auto distrDestType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
                                           distributedDest,
                                           insertOp.getMixedPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
      // Positions are static here (getAsIntegers asserts on dynamic values).
      SmallVector<int64_t> newPos = getAsIntegers(pos);
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = arith::ConstantIndexOp::create(
          rewriter, loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane =
          arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
                                                   distributedDest, newPos);
        scf::YieldOp::create(builder, loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        scf::YieldOp::create(builder, loc, distributedDest);
      };
      newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
                                    /*thenBuilder=*/insertingBuilder,
                                    /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};
1807
1808/// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1809/// the scf.if is the last operation in the region so that it doesn't
1810/// change the order of execution. This creates a new scf.if after the
1811/// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1812/// the "inner" WarpExecuteOnLane0Op. Example:
1813/// ```
1814/// gpu.warp_execute_on_lane_0(%laneid)[32] {
1815/// %payload = ... : vector<32xindex>
1816/// scf.if %pred {
1817/// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1818/// vector<32xindex>
1819/// }
1820/// gpu.yield
1821/// }
/// ```
/// To:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1824/// %payload = ... : vector<32xindex>
1825/// gpu.yield %payload : vector<32xindex>
1826/// }
1827/// scf.if %pred {
1828/// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1829/// ^bb0(%arg1: vector<32xindex>):
1830/// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1831/// }
1832/// }
1833/// ```
struct WarpOpScfIfOp : public WarpDistributionPattern {
  // `distributionMapFn` maps a yielded vector value to the affine map that
  // describes how its dimensions are distributed across lanes.
  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // Only pick up `IfOp` if it is the last op in the region.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
    if (!ifOp)
      return failure();

    // The current `WarpOp` can yield two types of values:
    // 1. Not results of `IfOp`:
    //    Preserve them in the new `WarpOp`.
    //    Collect their yield index to remap the usages.
    // 2. Results of `IfOp`:
    //    They are not part of the new `WarpOp` results.
    //    Map current warp's yield operand index to `IfOp` result idx.
    SmallVector<Value> nonIfYieldValues;
    SmallVector<unsigned> nonIfYieldIndices;
    llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
      if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
        nonIfYieldValues.push_back(yieldOperand.get());
        nonIfYieldIndices.push_back(yieldOperandIdx);
        continue;
      }
      OpResult ifResult = cast<OpResult>(yieldOperand.get());
      const unsigned ifResultIdx = ifResult.getResultNumber();
      ifResultMapping[yieldOperandIdx] = ifResultIdx;
      // If this `ifOp` result is vector type and it is yielded by the
      // `WarpOp`, we keep track the distributed type for this result.
      if (!isa<VectorType>(ifResult.getType()))
        continue;
      VectorType distType =
          cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
      ifResultDistTypes[ifResultIdx] = distType;
    }

    // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
    // them
    auto [escapingValuesThen, escapingValueInputTypesThen,
          escapingValueDistTypesThen] =
        getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
                                     distributionMapFn);
    auto [escapingValuesElse, escapingValueInputTypesElse,
          escapingValueDistTypesElse] =
        getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
                                     distributionMapFn);
    // A null (empty) dist type means a value could not be distributed; bail.
    if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
        llvm::is_contained(escapingValueDistTypesElse, Type{}))
      return failure();

    // The new `WarpOp` groups yields values in following order:
    // 1. Branch condition
    // 2. Escaping values then branch
    // 3. Escaping values else branch
    // 4. All non-`ifOp` yielded values.
    SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
    newWarpOpYieldValues.append(escapingValuesThen.begin(),
                                escapingValuesThen.end());
    newWarpOpYieldValues.append(escapingValuesElse.begin(),
                                escapingValuesElse.end());
    SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
    newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
                              escapingValueDistTypesThen.end());
    newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                              escapingValueDistTypesElse.end());

    for (auto [idx, val] :
         llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
      newWarpOpYieldValues.push_back(val);
      newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
    }
    // Replace the old `WarpOp` with the new one that has additional yield
    // values and types.
    SmallVector<size_t> newIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
    // `ifOp` returns the result of the inner warp op.
    SmallVector<Type> newIfOpDistResTypes;
    for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
      Type distType = cast<Value>(res).getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        AffineMap map = distributionMapFn(cast<Value>(res));
        // Fallback to affine map if the dist result was not previously recorded
        distType = ifResultDistTypes.count(i)
                       ? ifResultDistTypes[i]
                       : getDistributedType(
                             vecType, map,
                             map.isEmpty() ? 1 : newWarpOp.getWarpSize());
      }
      newIfOpDistResTypes.push_back(distType);
    }
    // Create a new `IfOp` outside the new `WarpOp` region.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    // `newIndices[0]` is the yielded branch condition.
    auto newIfOp = scf::IfOp::create(
        rewriter, ifOp.getLoc(), newIfOpDistResTypes,
        newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
        static_cast<bool>(ifOp.elseBlock()));
    // Wraps one branch of the original `ifOp` in an inner `WarpOp` placed in
    // the corresponding branch of the new `ifOp`. `warpResRangeStart` is the
    // offset into `newIndices` where this branch's escaping values begin.
    auto encloseRegionInWarpOp =
        [&](Block *oldIfBranch, Block *newIfBranch,
            llvm::SmallSetVector<Value, 32> &escapingValues,
            SmallVector<Type> &escapingValueInputTypes,
            size_t warpResRangeStart) {
          OpBuilder::InsertionGuard g(rewriter);
          if (!newIfBranch)
            return;
          rewriter.setInsertionPointToStart(newIfBranch);
          llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
          SmallVector<Value> innerWarpInputVals;
          SmallVector<Type> innerWarpInputTypes;
          // Forward the escaping values (now results of the new outer warp
          // op) into the inner warp op as arguments.
          for (size_t i = 0; i < escapingValues.size();
               ++i, ++warpResRangeStart) {
            innerWarpInputVals.push_back(
                newWarpOp.getResult(newIndices[warpResRangeStart]));
            escapeValToBlockArgIndex[escapingValues[i]] =
                innerWarpInputTypes.size();
            innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
          }
          auto innerWarp = WarpExecuteOnLane0Op::create(
              rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
              newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
              innerWarpInputVals, innerWarpInputTypes);

          // Move the original branch body into the inner warp op.
          innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
          innerWarp.getWarpRegion().addArguments(
              innerWarpInputTypes,
              SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));

          // Re-terminate: the old scf.yield operands become the inner warp
          // op's gpu.yield operands.
          SmallVector<Value> yieldOperands;
          for (Value operand : oldIfBranch->getTerminator()->getOperands())
            yieldOperands.push_back(operand);
          rewriter.eraseOp(oldIfBranch->getTerminator());

          rewriter.setInsertionPointToEnd(innerWarp.getBody());
          gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
          rewriter.setInsertionPointAfter(innerWarp);
          scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());

          // Update any users of escaping values that were forwarded to the
          // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
          innerWarp.walk([&](Operation *op) {
            for (OpOperand &operand : op->getOpOperands()) {
              auto it = escapeValToBlockArgIndex.find(operand.get());
              if (it == escapeValToBlockArgIndex.end())
                continue;
              operand.set(innerWarp.getBodyRegion().getArgument(it->second));
            }
          });
          // Hoist any now-uniform scalar code out of the inner warp op.
          mlir::vector::moveScalarUniformCode(innerWarp);
        };
    // Then-branch escaping values start at newIndices[1] (after the
    // condition); else-branch values follow them.
    encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
                          &newIfOp.getThenRegion().front(), escapingValuesThen,
                          escapingValueInputTypesThen, 1);
    if (!ifOp.getElseRegion().empty())
      encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
                            &newIfOp.getElseRegion().front(),
                            escapingValuesElse, escapingValueInputTypesElse,
                            1 + escapingValuesThen.size());
    // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
    // result.
    for (auto [origIdx, newIdx] : ifResultMapping)
      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                    newIfOp.getResult(newIdx), newIfOp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};
2009
2010/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
2011/// the scf.ForOp is the last operation in the region so that it doesn't
2012/// change the order of execution. This creates a new scf.for region after the
2013/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
2014/// WarpExecuteOnLane0Op region. Example:
2015/// ```
2016/// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
2017/// ...
2018/// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
2019/// -> (vector<128xf32>) {
2020/// ...
2021/// scf.yield %r : vector<128xf32>
2022/// }
2023/// gpu.yield %v1 : vector<128xf32>
2024/// }
2025/// ```
/// To:
/// ```
2027/// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
2028/// ...
2029/// gpu.yield %v : vector<128xf32>
2030/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
2032/// -> (vector<4xf32>) {
2033/// %iw = gpu.warp_execute_on_lane_0(%laneid)
2034/// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
2035/// ^bb0(%arg: vector<128xf32>):
2036/// ...
2037/// gpu.yield %ir : vector<128xf32>
2038/// }
2039/// scf.yield %iw : vector<4xf32>
2040/// }
2041/// ```
struct WarpOpScfForOp : public WarpDistributionPattern {

  // `distributionMapFn` maps a yielded vector value to the affine map that
  // describes how its dimensions are distributed across lanes.
  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    gpu::YieldOp warpOpYield = warpOp.getTerminator();
    // Only pick up `ForOp` if it is the last op in the region.
    Operation *lastNode = warpOpYield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
    // Those Values need to be returned by the new warp op.
    auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
        getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
                                     distributionMapFn);
    // A null (empty) dist type means a value could not be distributed; bail.
    if (llvm::is_contained(escapingValueDistTypes, Type{}))
      return failure();
    // `WarpOp` can yield two types of values:
    // 1. Values that are not results of the `ForOp`:
    //    These values must also be yielded by the new `WarpOp`. Also, we need
    //    to record the index mapping for these values to replace them later.
    // 2. Values that are results of the `ForOp`:
    //    In this case, we record the index mapping between the `WarpOp` result
    //    index and matching `ForOp` result index.
    //    Additionally, we keep track of the distributed types for all `ForOp`
    //    vector results.
    SmallVector<Value> nonForYieldedValues;
    SmallVector<unsigned> nonForResultIndices;
    llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
    llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
    for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
      // Yielded value is not a result of the forOp.
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
        nonForYieldedValues.push_back(yieldOperand.get());
        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
        continue;
      }
      OpResult forResult = cast<OpResult>(yieldOperand.get());
      unsigned int forResultNumber = forResult.getResultNumber();
      forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
      // If this `ForOp` result is vector type and it is yielded by the
      // `WarpOp`, we keep track the distributed type for this result.
      if (!isa<VectorType>(forResult.getType()))
        continue;
      VectorType distType = cast<VectorType>(
          warpOp.getResult(yieldOperand.getOperandNumber()).getType());
      forResultDistTypes[forResultNumber] = distType;
    }

    // Newly created `WarpOp` will yield values in following order:
    // 1. Loop bounds.
    // 2. All init args of the `ForOp`.
    // 3. All escaping values.
    // 4. All non-`ForOp` yielded values.
    SmallVector<Value> newWarpOpYieldValues;
    SmallVector<Type> newWarpOpDistTypes;
    newWarpOpYieldValues.insert(
        newWarpOpYieldValues.end(),
        {forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()});
    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
                              {forOp.getLowerBound().getType(),
                               forOp.getUpperBound().getType(),
                               forOp.getStep().getType()});
    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
      newWarpOpYieldValues.push_back(initArg);
      // Compute the distributed type for this init arg.
      Type distType = initArg.getType();
      if (auto vecType = dyn_cast<VectorType>(distType)) {
        // If the `ForOp` result corresponds to this init arg is already yielded
        // we can get the distributed type from `forResultDistTypes` map.
        // Otherwise, we compute it using distributionMapFn.
        AffineMap map = distributionMapFn(initArg);
        distType =
            forResultDistTypes.count(i)
                ? forResultDistTypes[i]
                : getDistributedType(vecType, map,
                                     map.isEmpty() ? 1 : warpOp.getWarpSize());
      }
      newWarpOpDistTypes.push_back(distType);
    }
    // Insert escaping values and their distributed types.
    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
                                escapingValues.begin(), escapingValues.end());
    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
                              escapingValueDistTypes.begin(),
                              escapingValueDistTypes.end());
    // Next, we insert all non-`ForOp` yielded values and their distributed
    // types.
    for (auto [i, v] :
         llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
      newWarpOpYieldValues.push_back(v);
      newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
    }
    // Create the new `WarpOp` with the updated yield values and types.
    SmallVector<size_t> newIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);

    // Next, we create a new `ForOp` with the init args yielded by the new
    // `WarpOp`.
    const unsigned initArgsStartIdx = 3; // After loop bounds.
    const unsigned escapingValuesStartIdx =
        initArgsStartIdx +
        forOp.getInitArgs().size(); // `ForOp` init args are positioned before
                                    // escaping values in the new `WarpOp`.
    SmallVector<Value> newForOpOperands;
    for (size_t i = initArgsStartIdx; i < escapingValuesStartIdx; ++i)
      newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));

    // Create a new `ForOp` outside the new `WarpOp` region.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newForOp = scf::ForOp::create(
        rewriter, forOp.getLoc(),
        /*lowerBound=*/newWarpOp.getResult(newIndices[0]),
        /*upperBound=*/newWarpOp.getResult(newIndices[1]),
        /*step=*/newWarpOp.getResult(newIndices[2]), newForOpOperands,
        /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
    // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
    // newly created `ForOp`. This `WarpOp` will contain all ops that were
    // contained within the original `ForOp` body.
    rewriter.setInsertionPointToStart(newForOp.getBody());

    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
                                      newForOp.getRegionIterArgs().end());
    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
                                         forOp.getResultTypes().end());
    // Escaping values are forwarded to the inner `WarpOp` as its (additional)
    // arguments. We keep track of the mapping between these values and their
    // argument index in the inner `WarpOp` (to replace users later).
    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
    for (size_t i = escapingValuesStartIdx;
         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
      innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
          innerWarpInputType.size();
      innerWarpInputType.push_back(
          escapingValueInputTypes[i - escapingValuesStartIdx]);
    }
    // Create the inner `WarpOp` with the new input values and types.
    auto innerWarp = WarpExecuteOnLane0Op::create(
        rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
        newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
        innerWarpInputType);

    // Inline the `ForOp` body into the inner `WarpOp` body.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments())
      argMapping.push_back(args);

    // Drop any extra entries so the mapping matches the old body's arg count.
    argMapping.resize(forOp.getBody()->getNumArguments());
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);

    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);

    // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
    // original `ForOp` results.
    rewriter.setInsertionPointToEnd(innerWarp.getBody());
    gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    // Insert a scf.yield op at the end of the new `ForOp` body that yields
    // the inner `WarpOp` results.
    if (!innerWarp.getResults().empty())
      scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());

    // Update the users of the new `WarpOp` results that were coming from the
    // original `ForOp` to the corresponding new `ForOp` result.
    for (auto [origIdx, newIdx] : forResultMapping)
      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                    newForOp.getResult(newIdx), newForOp);
    // Update any users of escaping values that were forwarded to the
    // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
    newForOp.walk([&](Operation *op) {
      for (OpOperand &operand : op->getOpOperands()) {
        auto it = argIndexMapping.find(operand.get());
        if (it == argIndexMapping.end())
          continue;
        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
      }
    });

    // Finally, hoist out any now uniform code from the inner `WarpOp`.
    mlir::vector::moveScalarUniformCode(innerWarp);
    return success();
  }

private:
  DistributionMapFn distributionMapFn;
};
2237
2238/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
2239/// The vector is reduced in parallel. Currently limited to vector size
2240/// matching the warpOp size. E.g.:
2241/// ```
2242/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
2243/// %0 = "some_def"() : () -> (vector<32xf32>)
2244/// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
2245/// gpu.yield %1 : f32
2246/// }
2247/// ```
2248/// is lowered to:
2249/// ```
2250/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
2251/// %1 = "some_def"() : () -> (vector<32xf32>)
2252/// gpu.yield %1 : vector<32xf32>
2253/// }
2254/// %a = vector.extract %0[0] : f32 from vector<1xf32>
2255/// %r = ("warp.reduction %a")
2256/// ```
2257struct WarpOpReduction : public WarpDistributionPattern {
2258 WarpOpReduction(MLIRContext *context,
2259 DistributedReductionFn distributedReductionFn,
2260 PatternBenefit benefit = 1)
2261 : WarpDistributionPattern(context, benefit),
2262 distributedReductionFn(std::move(distributedReductionFn)) {}
2263
2264 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2265 PatternRewriter &rewriter) const override {
2266 OpOperand *yieldOperand =
2267 getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2268 if (!yieldOperand)
2269 return failure();
2270
2271 auto reductionOp =
2272 cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
2273 auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2274 // Only rank 1 vectors supported.
2275 if (vectorType.getRank() != 1)
2276 return rewriter.notifyMatchFailure(
2277 warpOp, "Only rank 1 reductions can be distributed.");
2278 // Only warp_size-sized vectors supported.
2279 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2280 return rewriter.notifyMatchFailure(
2281 warpOp, "Reduction vector dimension must match was size.");
2282 if (!reductionOp.getType().isIntOrFloat())
2283 return rewriter.notifyMatchFailure(
2284 warpOp, "Reduction distribution currently only supports floats and "
2285 "integer types.");
2286
2287 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2288 // Return vector that will be reduced from the WarpExecuteOnLane0Op.
2289 unsigned operandIndex = yieldOperand->getOperandNumber();
2290 SmallVector<Value> yieldValues = {reductionOp.getVector()};
2291 SmallVector<Type> retTypes = {
2292 VectorType::get({numElements}, reductionOp.getType())};
2293 if (reductionOp.getAcc()) {
2294 yieldValues.push_back(reductionOp.getAcc());
2295 retTypes.push_back(reductionOp.getAcc().getType());
2296 }
2297 SmallVector<size_t> newRetIndices;
2298 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2299 rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2300 rewriter.setInsertionPointAfter(newWarpOp);
2301
2302 // Obtain data to reduce for a single lane.
2303 Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2304 // Distribute and reduce across threads.
2305 Value fullReduce =
2306 distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2307 reductionOp.getKind(), newWarpOp.getWarpSize());
2308 if (reductionOp.getAcc()) {
2309 fullReduce = vector::makeArithReduction(
2310 rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2311 newWarpOp.getResult(newRetIndices[1]));
2312 }
2313 rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
2314 return success();
2315 }
2316
2317private:
2318 DistributedReductionFn distributedReductionFn;
2319};
2320
2321} // namespace
2322
2324 RewritePatternSet &patterns,
2326 patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
2327}
2328
2329void mlir::vector::populateDistributeTransferWriteOpPatterns(
2330 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2331 unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2332 patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2333 maxNumElementsToExtract, benefit);
2334}
2335
2336void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2337 RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2338 const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2339 PatternBenefit readBenefit) {
2340 patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2341 patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2342 WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
2343 WarpOpConstant, WarpOpInsertScalar, WarpOpInsert,
2344 WarpOpCreateMask<vector::CreateMaskOp>,
2345 WarpOpCreateMask<vector::ConstantMaskOp>,
2346 WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2347 patterns.getContext(), benefit);
2348 patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2349 benefit);
2350 patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2351 benefit);
2352 patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2353 benefit);
2354}
2355
2356void mlir::vector::populateDistributeReduction(
2357 RewritePatternSet &patterns,
2358 const DistributedReductionFn &distributedReductionFn,
2359 PatternBenefit benefit) {
2360 patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2361 benefit);
2362}
2363
2364/// Helper to know if an op can be hoisted out of the region.
2365static bool canBeHoisted(Operation *op,
2366 function_ref<bool(Value)> definedOutside) {
2367 return llvm::all_of(op->getOperands(), definedOutside) &&
2368 isMemoryEffectFree(op) && op->getNumRegions() == 0;
2369}
2370
2371void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2372 Block *body = warpOp.getBody();
2373
2374 // Keep track of the ops we want to hoist.
2375 llvm::SmallSetVector<Operation *, 8> opsToMove;
2376
2377 // Helper to check if a value is or will be defined outside of the region.
2378 auto isDefinedOutsideOfBody = [&](Value value) {
2379 auto *definingOp = value.getDefiningOp();
2380 return (definingOp && opsToMove.count(definingOp)) ||
2381 warpOp.isDefinedOutsideOfRegion(value);
2382 };
2383
2384 // Do not use walk here, as we do not want to go into nested regions and hoist
2385 // operations from there.
2386 for (auto &op : body->without_terminator()) {
2387 bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2388 return isa<VectorType>(result.getType());
2389 });
2390 if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2391 opsToMove.insert(&op);
2392 }
2393
2394 // Move all the ops marked as uniform outside of the region.
2395 for (Operation *op : opsToMove)
2396 op->moveBefore(warpOp);
2397}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static int getDistributedDim(VectorType sequentialType, VectorType distributedType)
Given a sequential and distributed vector type, returns the distributed dimension.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isEmpty() const
Returns true if this affine map is an empty map, i.e., () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:376
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:878
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:412
unsigned getNumOperands()
Definition Operation.h:375
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:444
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
std::function< AffineMap(Value)> DistributionMapFn
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
This represents an operation in an abstracted form, suitable for use with the builder APIs.
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, SmallVector< size_t > &indices) const
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
bool delinearizeLaneId(OpBuilder &builder, Location loc, ArrayRef< int64_t > originalShape, ArrayRef< int64_t > distributedShape, int64_t warpSize, Value laneId, SmallVectorImpl< Value > &delinearizedIds) const
Delinearize the given laneId into multiple dimensions, where each dimension's size is determined by o...
WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes) const
Helper to create a new WarpExecuteOnLane0Op with different signature.
virtual LogicalResult matchAndRewrite(WarpExecuteOnLane0Op op, PatternRewriter &rewriter) const override=0
OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, llvm::function_ref< bool(Operation *)> fn) const
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.