VectorDistribute.cpp
1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/AffineExpr.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include <utility>
23 
24 using namespace mlir;
25 using namespace mlir::vector;
26 using namespace mlir::gpu;
27 
28 /// Currently the distribution map is implicit based on the vector shape. In the
29 /// future it will be part of the op.
30 /// Example:
31 /// ```
32 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
33 /// ...
34 /// gpu.yield %3 : vector<32x16x64xf32>
35 /// }
36 /// ```
37 /// Would have an implicit map of:
38 /// `(d0, d1, d2) -> (d0, d2)`
39 static AffineMap calculateImplicitMap(VectorType sequentialType,
40  VectorType distributedType) {
41  SmallVector<AffineExpr> perm;
42  perm.reserve(1);
43  // Check which dimensions of the sequential type are different than the
44  // dimensions of the distributed type to know the distributed dimensions. Then
45  // associate each distributed dimension to an ID in order.
46  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
47  if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
48  perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
49  }
50  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
51  distributedType.getContext());
52  return map;
53 }
54 
55 namespace {
56 
57 /// Helper struct to create the load / store operations that permit transit
58 /// through the parallel / sequential and the sequential / parallel boundaries
59 /// when performing the warp op to scf.if rewrite (`WarpOpToScfIfPattern`).
60 ///
61 /// The vector distribution dimension is inferred from the vector types.
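///
/// Illustrative example (added annotation, not part of the upstream comment):
/// with a warp of 32 lanes, a sequential vector<32xf32> transits through a
/// memref<32xf32> buffer. Lane `laneId` stores/loads its distributed
/// vector<1xf32> slice at offset `laneId * 1`, while the sequential value is
/// accessed with all-zero indices.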
62 struct DistributedLoadStoreHelper {
63  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
64  Value laneId, Value zero)
65  : sequentialVal(sequentialVal), distributedVal(distributedVal),
66  laneId(laneId), zero(zero) {
67  sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
68  distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
69  if (sequentialVectorType && distributedVectorType)
70  distributionMap =
71  calculateImplicitMap(sequentialVectorType, distributedVectorType);
72  }
73 
74  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
75  int64_t distributedSize = distributedVectorType.getDimSize(index);
76  AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
77  return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
78  ArrayRef<Value>{laneId});
79  }
80 
81  /// Create a store during the process of distributing the
82  /// `gpu.warp_execute_on_lane_0` op.
83  /// Vector distribution assumes the following convention regarding the
84  /// temporary buffers that are created to transition values. This **must**
85  /// be properly specified in the `options.warpAllocationFn`:
86  /// 1. scalars of type T transit through a memref<1xT>.
87  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
88  Operation *buildStore(RewriterBase &b, Location loc, Value val,
89  Value buffer) {
90  assert((val == distributedVal || val == sequentialVal) &&
91  "Must store either the preregistered distributed or the "
92  "preregistered sequential value.");
93  // Scalar case can directly use memref.store.
94  if (!isa<VectorType>(val.getType()))
95  return b.create<memref::StoreOp>(loc, val, buffer, zero);
96 
97  // Vector case must use vector::TransferWriteOp which will later lower to
98  // vector.store or memref.store depending on further lowerings.
99  int64_t rank = sequentialVectorType.getRank();
100  SmallVector<Value> indices(rank, zero);
101  if (val == distributedVal) {
102  for (auto dimExpr : distributionMap.getResults()) {
103  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
104  indices[index] = buildDistributedOffset(b, loc, index);
105  }
106  }
107  SmallVector<bool> inBounds(indices.size(), true);
108  return b.create<vector::TransferWriteOp>(
109  loc, val, buffer, indices,
110  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
111  }
112 
113  /// Create a load during the process of distributing the
114  /// `gpu.warp_execute_on_lane_0` op.
115  /// Vector distribution assumes the following convention regarding the
116  /// temporary buffers that are created to transition values. This **must**
117  /// be properly specified in the `options.warpAllocationFn`:
118  /// 1. scalars of type T transit through a memref<1xT>.
119  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
120  ///
121  /// When the requested type is not distributed, the load is not distributed
122  /// either, matching the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
123  ///
124  /// Example:
125  ///
126  /// ```
127  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
128  /// gpu.yield %cst : f32
129  /// }
130  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
131  /// ```
132  /// This behavior is described in more detail in the documentation of the op.
133  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
134 
135  // Scalar case can directly use memref.load.
136  if (!isa<VectorType>(type))
137  return b.create<memref::LoadOp>(loc, buffer, zero);
138 
139  // Other cases must be vector atm.
140  // Vector case must use vector::TransferReadOp which will later lower to
141  // vector.load or memref.load depending on further lowerings.
142  assert((type == distributedVectorType || type == sequentialVectorType) &&
143  "Must store either the preregistered distributed or the "
144  "preregistered sequential type.");
145  SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
146  if (type == distributedVectorType) {
147  for (auto dimExpr : distributionMap.getResults()) {
148  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
149  indices[index] = buildDistributedOffset(b, loc, index);
150  }
151  }
152  SmallVector<bool> inBounds(indices.size(), true);
153  return b.create<vector::TransferReadOp>(
154  loc, cast<VectorType>(type), buffer, indices,
155  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
156  }
157 
158  Value sequentialVal, distributedVal, laneId, zero;
159  VectorType sequentialVectorType, distributedVectorType;
160  AffineMap distributionMap;
161 };
162 
163 } // namespace
164 
165 // Clones `op` into a new operation that takes `operands` and returns
166 // `resultTypes`.
167 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
168  Location loc, Operation *op,
169  ArrayRef<Value> operands,
170  ArrayRef<Type> resultTypes) {
171  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
172  op->getAttrs());
173  return rewriter.create(res);
174 }
175 
176 namespace {
177 
178 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
179 /// thread `laneId` executes the entirety of the computation.
180 ///
181 /// After the transformation:
182 /// - the IR within the scf.if op can be thought of as executing sequentially
183 /// (from the point of view of threads along `laneId`).
184 /// - the IR outside of the scf.if op can be thought of as executing in
185 /// parallel (from the point of view of threads along `laneId`).
186 ///
187 /// Values that need to transit through the parallel / sequential and the
188 /// sequential / parallel boundaries do so via reads and writes to a temporary
189 /// memory location.
190 ///
191 /// The transformation proceeds in multiple steps:
192 /// 1. Create the scf.if op.
193 /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
194 /// within the scf.if to transit the values captured from above.
195 /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
196 /// consistent within the scf.if.
197 /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
198 /// 5. Insert appropriate writes within scf.if and reads after the scf.if to
199 /// transit the values returned by the op.
200 /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
201 /// consistent after the scf.if.
202 /// 7. Perform late cleanups.
203 ///
204 /// All this assumes the vector distribution occurs along the most minor
205 /// distributed vector dimension.
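///
/// Illustrative sketch (added annotation; buffer names and the exact alloc /
/// sync ops are placeholders determined by `options.warpAllocationFn` and
/// `options.warpSyncronizationFn`):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32]
///     args(%d : vector<4xf32>) -> (vector<1xf32>) {
/// ^bb0(%s : vector<128xf32>):
///   %y = ... : vector<32xf32>
///   gpu.yield %y : vector<32xf32>
/// }
/// ```
/// becomes, roughly:
/// ```
/// %buf0 = <alloc memref<128xf32>>   // warpAllocationFn
/// %buf1 = <alloc memref<32xf32>>
/// vector.transfer_write %d, %buf0[%laneid * 4]
/// <sync>                            // warpSyncronizationFn
/// scf.if %laneid == %c0 {
///   %s = vector.transfer_read %buf0[%c0] : vector<128xf32>
///   %y = ... : vector<32xf32>
///   vector.transfer_write %y, %buf1[%c0]
/// }
/// <sync>
/// %r = vector.transfer_read %buf1[%laneid] : vector<1xf32>
/// ```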
206 struct WarpOpToScfIfPattern : public WarpDistributionPattern {
207  WarpOpToScfIfPattern(MLIRContext *context,
208  const WarpExecuteOnLane0LoweringOptions &options,
209  PatternBenefit benefit = 1)
210  : WarpDistributionPattern(context, benefit), options(options) {}
211 
212  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
213  PatternRewriter &rewriter) const override {
214  assert(warpOp.getBodyRegion().hasOneBlock() &&
215  "expected WarpOp with single block");
216  Block *warpOpBody = &warpOp.getBodyRegion().front();
217  Location loc = warpOp.getLoc();
218 
219  // Passed all checks. Start rewriting.
220  OpBuilder::InsertionGuard g(rewriter);
221  rewriter.setInsertionPoint(warpOp);
222 
223  // Step 1: Create scf.if op.
224  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
225  Value isLane0 = rewriter.create<arith::CmpIOp>(
226  loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
227  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
228  /*withElseRegion=*/false);
229  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
230 
231  // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
232  // reads within the scf.if to transit the values captured from above.
233  SmallVector<Value> bbArgReplacements;
234  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
235  Value sequentialVal = warpOpBody->getArgument(it.index());
236  Value distributedVal = it.value();
237  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
238  warpOp.getLaneid(), c0);
239 
240  // Create buffer before the ifOp.
241  rewriter.setInsertionPoint(ifOp);
242  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
243  sequentialVal.getType());
244  // Store distributed vector into buffer, before the ifOp.
245  helper.buildStore(rewriter, loc, distributedVal, buffer);
246  // Load sequential vector from buffer, inside the ifOp.
247  rewriter.setInsertionPointToStart(ifOp.thenBlock());
248  bbArgReplacements.push_back(
249  helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
250  }
251 
252  // Step 3. Insert sync after all the stores and before all the loads.
253  if (!warpOp.getArgs().empty()) {
254  rewriter.setInsertionPoint(ifOp);
255  options.warpSyncronizationFn(loc, rewriter, warpOp);
256  }
257 
258  // Step 4. Move body of warpOp to ifOp.
259  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
260 
261  // Step 5. Insert appropriate writes within scf.if and reads after the
262  // scf.if to transit the values returned by the op.
263  // TODO: at this point, we can reuse the shared memory from previous
264  // buffers.
265  SmallVector<Value> replacements;
266  auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
267  Location yieldLoc = yieldOp.getLoc();
268  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
269  Value sequentialVal = it.value();
270  Value distributedVal = warpOp->getResult(it.index());
271  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
272  warpOp.getLaneid(), c0);
273 
274  // Create buffer before the ifOp.
275  rewriter.setInsertionPoint(ifOp);
276  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
277  sequentialVal.getType());
278 
279  // Store yielded value into buffer, inside the ifOp, before the
280  // terminator.
281  rewriter.setInsertionPoint(yieldOp);
282  helper.buildStore(rewriter, loc, sequentialVal, buffer);
283 
284  // Load distributed value from buffer, after the warpOp.
285  rewriter.setInsertionPointAfter(ifOp);
286  // Result type and yielded value type are the same. This is a broadcast.
287  // E.g.:
288  // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
289  // gpu.yield %cst : f32
290  // }
291  // Both types are f32. The constant %cst is broadcasted to all lanes.
292  // This is described in more detail in the documentation of the op.
293  replacements.push_back(
294  helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
295  }
296 
297  // Step 6. Insert sync after all the stores and before all the loads.
298  if (!yieldOp.getOperands().empty()) {
299  rewriter.setInsertionPointAfter(ifOp);
300  options.warpSyncronizationFn(loc, rewriter, warpOp);
301  }
302 
303  // Step 7. Delete terminator and add empty scf.yield.
304  rewriter.eraseOp(yieldOp);
305  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
306  rewriter.create<scf::YieldOp>(yieldLoc);
307 
308  // Compute replacements for WarpOp results.
309  rewriter.replaceOp(warpOp, replacements);
310 
311  return success();
312  }
313 
314 private:
315  const WarpExecuteOnLane0LoweringOptions &options;
316 };
317 
318 /// Return the distributed vector type based on the original type and the
319 /// distribution map. The map is expected to have as many dimensions as the
320 /// rank of the original type and should be a projection in which the results
321 /// are the distributed dimensions; there is one map result per distributed
322 /// dimension.
323 /// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
324 /// and a warp size of 16 has its second dimension (associated with d1)
325 /// distributed, returning vector<16x2x64>.
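/// A second worked example (added annotation): vector<4x8x32xf32> with map
/// (d0, d1, d2) -> (d0, d1) and a warp size of 32 first spreads 4 lanes over
/// d0 (leaving 32 / 4 = 8), then the remaining 8 lanes over d1, returning
/// vector<1x1x32xf32>.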
326 static VectorType getDistributedType(VectorType originalType, AffineMap map,
327  int64_t warpSize) {
328  SmallVector<int64_t> targetShape(originalType.getShape());
329  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
330  unsigned position = map.getDimPosition(i);
331  if (targetShape[position] % warpSize != 0) {
332  if (warpSize % targetShape[position] != 0) {
333  return VectorType();
334  }
335  warpSize /= targetShape[position];
336  targetShape[position] = 1;
337  continue;
338  }
339  targetShape[position] = targetShape[position] / warpSize;
340  warpSize = 1;
341  break;
342  }
343  if (warpSize != 1) {
344  return VectorType();
345  }
346  VectorType targetType =
347  VectorType::get(targetShape, originalType.getElementType());
348  return targetType;
349 }
350 
351 /// Distribute transfer_write ops based on the affine map returned by
352 /// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
353 /// will not be distributed (this limit should be smaller than the warp size).
354 ///
355 /// Example:
356 /// ```
357 /// %0 = gpu.warp_execute_on_lane_0(%id){
358 /// ...
359 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
360 /// gpu.yield
361 /// }
362 /// ```
363 /// To
364 /// ```
365 /// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
366 /// ...
367 /// gpu.yield %v : vector<32xf32>
368 /// }
369 /// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
370 struct WarpOpTransferWrite : public WarpDistributionPattern {
371  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
372  unsigned maxNumElementsToExtract, PatternBenefit b = 1)
373  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
374  maxNumElementsToExtract(maxNumElementsToExtract) {}
375 
376  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
377  /// are multiples of the distribution ratio are supported at the moment.
378  LogicalResult tryDistributeOp(RewriterBase &rewriter,
379  vector::TransferWriteOp writeOp,
380  WarpExecuteOnLane0Op warpOp) const {
381  VectorType writtenVectorType = writeOp.getVectorType();
382 
383  // 1. If the write is 0-D, it cannot be distributed by this pattern; bail
384  // out (such writes may still be handled by `tryExtractOp`).
385  if (writtenVectorType.getRank() == 0)
386  return failure();
387 
388  // 2. Compute the distributed type.
389  AffineMap map = distributionMapFn(writeOp.getVector());
390  VectorType targetType =
391  getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
392  if (!targetType)
393  return failure();
394 
395  // 2.5 Compute the distributed type for the new mask;
396  VectorType maskType;
397  if (writeOp.getMask()) {
398  // TODO: Distribution of masked writes with non-trivial permutation maps
399  // requires the distribution of the mask to elementwise match the
400  // distribution of the permuted written vector. Currently the details
401  // of which lane is responsible for which element are captured strictly
402  // by shape information on the warp op, and thus require materializing
403  // the permutation in IR.
404  if (!writeOp.getPermutationMap().isMinorIdentity())
405  return failure();
406  maskType =
407  getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
408  }
409 
410  // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
411  // the rest.
412  vector::TransferWriteOp newWriteOp =
413  cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
414 
415  // 4. Reindex the write using the distribution map.
416  auto newWarpOp =
417  newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
418 
419  // Delinearize the lane id based on the way threads are divided across the
420  // vector. To get the number of threads per vector dimension, divide the
421  // sequential size by the distributed size along each dim.
422  rewriter.setInsertionPoint(newWriteOp);
423  SmallVector<OpFoldResult> delinearizedIdSizes;
424  for (auto [seqSize, distSize] :
425  llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
426  assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
427  delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
428  }
429  SmallVector<Value> delinearized;
430  if (map.getNumResults() > 1) {
431  delinearized = rewriter
432  .create<mlir::affine::AffineDelinearizeIndexOp>(
433  newWarpOp.getLoc(), newWarpOp.getLaneid(),
434  delinearizedIdSizes)
435  .getResults();
436  } else {
437  // If there is only one map result, we can elide the delinearization
438  // op and use the lane id directly.
439  delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
440  }
441 
442  AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
443  Location loc = newWriteOp.getLoc();
444  SmallVector<Value> indices(newWriteOp.getIndices().begin(),
445  newWriteOp.getIndices().end());
446  for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
447  AffineExpr d0, d1;
448  bindDims(newWarpOp.getContext(), d0, d1);
449  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
450  if (!indexExpr)
451  continue;
452  unsigned indexPos = indexExpr.getPosition();
453  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
454  Value laneId = delinearized[vectorPos];
455  auto scale =
456  rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
457  indices[indexPos] = affine::makeComposedAffineApply(
458  rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
459  }
460  newWriteOp.getIndicesMutable().assign(indices);
461 
462  return success();
463  }
464 
465  /// Extract TransferWriteOps of at most `maxNumElementsToExtract` elements
/// into a separate warp op.
466  LogicalResult tryExtractOp(RewriterBase &rewriter,
467  vector::TransferWriteOp writeOp,
468  WarpExecuteOnLane0Op warpOp) const {
469  Location loc = writeOp.getLoc();
470  VectorType vecType = writeOp.getVectorType();
471 
472  if (vecType.getNumElements() > maxNumElementsToExtract) {
473  return rewriter.notifyMatchFailure(
474  warpOp,
475  llvm::formatv(
476  "writes more elements ({0}) than allowed to extract ({1})",
477  vecType.getNumElements(), maxNumElementsToExtract));
478  }
479 
480  // Do not process warp ops that contain only TransferWriteOps.
481  if (llvm::all_of(warpOp.getOps(),
482  llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
483  return failure();
484 
485  SmallVector<Value> yieldValues = {writeOp.getVector()};
486  SmallVector<Type> retTypes = {vecType};
487  SmallVector<size_t> newRetIndices;
488  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
489  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
490  rewriter.setInsertionPointAfter(newWarpOp);
491 
492  // Create a second warp op that contains only writeOp.
493  auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
494  loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
495  Block &body = secondWarpOp.getBodyRegion().front();
496  rewriter.setInsertionPointToStart(&body);
497  auto newWriteOp =
498  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
499  newWriteOp.getValueToStoreMutable().assign(
500  newWarpOp.getResult(newRetIndices[0]));
501  rewriter.eraseOp(writeOp);
502  rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
503  return success();
504  }
505 
506  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
507  PatternRewriter &rewriter) const override {
508  auto yield = cast<gpu::YieldOp>(
509  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
510  Operation *lastNode = yield->getPrevNode();
511  auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
512  if (!writeOp)
513  return failure();
514 
515  Value maybeMask = writeOp.getMask();
516  if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
517  return writeOp.getVector() == value ||
518  (maybeMask && maybeMask == value) ||
519  warpOp.isDefinedOutsideOfRegion(value);
520  }))
521  return failure();
522 
523  if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
524  return success();
525 
526  // Masked writes not supported for extraction.
527  if (writeOp.getMask())
528  return failure();
529 
530  if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
531  return success();
532 
533  return failure();
534  }
535 
536 private:
537  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
538  /// execute op with the proper return type. The new write op is updated to
539  /// write the result of the new warp execute op. The old `writeOp` is deleted.
540  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
541  WarpExecuteOnLane0Op warpOp,
542  vector::TransferWriteOp writeOp,
543  VectorType targetType,
544  VectorType maybeMaskType) const {
545  assert(writeOp->getParentOp() == warpOp &&
546  "write must be nested immediately under warp");
547  OpBuilder::InsertionGuard g(rewriter);
548  SmallVector<size_t> newRetIndices;
549  WarpExecuteOnLane0Op newWarpOp;
550  if (maybeMaskType) {
551  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
552  rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
553  TypeRange{targetType, maybeMaskType}, newRetIndices);
554  } else {
555  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
556  rewriter, warpOp, ValueRange{{writeOp.getVector()}},
557  TypeRange{targetType}, newRetIndices);
558  }
559  rewriter.setInsertionPointAfter(newWarpOp);
560  auto newWriteOp =
561  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
562  rewriter.eraseOp(writeOp);
563  newWriteOp.getValueToStoreMutable().assign(
564  newWarpOp.getResult(newRetIndices[0]));
565  if (maybeMaskType)
566  newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
567  return newWriteOp;
568  }
569 
570  DistributionMapFn distributionMapFn;
571  unsigned maxNumElementsToExtract = 1;
572 };
573 
574 /// Sink out elementwise op feeding into a warp op yield.
575 /// ```
576 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
577 /// ...
578 /// %3 = arith.addf %1, %2 : vector<32xf32>
579 /// gpu.yield %3 : vector<32xf32>
580 /// }
581 /// ```
582 /// To
583 /// ```
584 /// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
585 /// vector<1xf32>, vector<1xf32>) {
586 /// ...
587 /// %4 = arith.addf %2, %3 : vector<32xf32>
588 /// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
589 /// vector<32xf32>
590 /// }
591 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
592 struct WarpOpElementwise : public WarpDistributionPattern {
593  using Base::Base;
594  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
595  PatternRewriter &rewriter) const override {
596  OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
597  return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
598  });
599  if (!yieldOperand)
600  return failure();
601 
602  Operation *elementWise = yieldOperand->get().getDefiningOp();
603  unsigned operandIndex = yieldOperand->getOperandNumber();
604  Value distributedVal = warpOp.getResult(operandIndex);
605  SmallVector<Value> yieldValues;
606  SmallVector<Type> retTypes;
607  Location loc = warpOp.getLoc();
608  for (OpOperand &operand : elementWise->getOpOperands()) {
609  Type targetType;
610  if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
611  // If the result type is a vector, the operands must also be vectors.
612  auto operandType = cast<VectorType>(operand.get().getType());
613  targetType =
614  VectorType::get(vecType.getShape(), operandType.getElementType());
615  } else {
616  auto operandType = operand.get().getType();
617  assert(!isa<VectorType>(operandType) &&
618  "unexpected yield of vector from op with scalar result type");
619  targetType = operandType;
620  }
621  retTypes.push_back(targetType);
622  yieldValues.push_back(operand.get());
623  }
624  SmallVector<size_t> newRetIndices;
625  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
626  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
627  rewriter.setInsertionPointAfter(newWarpOp);
628  SmallVector<Value> newOperands(elementWise->getOperands().begin(),
629  elementWise->getOperands().end());
630  for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
631  newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
632  }
633  OpBuilder::InsertionGuard g(rewriter);
634  rewriter.setInsertionPointAfter(newWarpOp);
635  Operation *newOp = cloneOpWithOperandsAndTypes(
636  rewriter, loc, elementWise, newOperands,
637  {newWarpOp.getResult(operandIndex).getType()});
638  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
639  newOp->getResult(0));
640  return success();
641  }
642 };
643 
644 /// Sink out splat constant op feeding into a warp op yield.
645 /// ```
646 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
647 /// ...
648 /// %cst = arith.constant dense<2.0> : vector<32xf32>
649 /// gpu.yield %cst : vector<32xf32>
650 /// }
651 /// ```
652 /// To
653 /// ```
654 /// gpu.warp_execute_on_lane_0(%arg0) {
655 /// ...
656 /// }
657 /// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
658 struct WarpOpConstant : public WarpDistributionPattern {
659  using Base::Base;
660  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
661  PatternRewriter &rewriter) const override {
662  OpOperand *yieldOperand =
663  getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
664  if (!yieldOperand)
665  return failure();
666  auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
667  auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
668  if (!dense)
669  return failure();
670  // Notify the rewriter that the warp op is changing (see the comment on
671  // the WarpOpTransferRead pattern).
672  rewriter.startOpModification(warpOp);
673  unsigned operandIndex = yieldOperand->getOperandNumber();
674  Attribute scalarAttr = dense.getSplatValue<Attribute>();
675  auto newAttr = DenseElementsAttr::get(
676  cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
677  Location loc = warpOp.getLoc();
678  rewriter.setInsertionPointAfter(warpOp);
679  Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
680  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
681  rewriter.finalizeOpModification(warpOp);
682  return success();
683  }
684 };
685 
686 /// Sink out transfer_read op feeding into a warp op yield.
687 /// ```
688 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
689 /// ...
690 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
691 /// vector<32xf32>
692 /// gpu.yield %2 : vector<32xf32>
693 /// }
694 /// ```
695 /// To
696 /// ```
697 /// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
698 /// vector<1xf32>, vector<1xf32>) {
699 /// ...
700 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32>
701 /// gpu.yield %2 : vector<32xf32>
702 /// }
703 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
704 struct WarpOpTransferRead : public WarpDistributionPattern {
705  using Base::Base;
706  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
707  PatternRewriter &rewriter) const override {
708  // Try to find a distributable yielded read. Note that this pattern can
709  // still fail at the end after distribution, in which case this might have
710  // missed another distributable read.
711  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
712  // Don't duplicate transfer_read ops when distributing.
713  return isa<vector::TransferReadOp>(op) && op->hasOneUse();
714  });
715  if (!operand)
716  return rewriter.notifyMatchFailure(
717  warpOp, "warp result is not a vector.transfer_read op");
718  auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
719 
720  // Source must be defined outside of the region.
721  if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
722  return rewriter.notifyMatchFailure(
723  read, "source must be defined outside of the region");
724 
725  unsigned operandIndex = operand->getOperandNumber();
726  Value distributedVal = warpOp.getResult(operandIndex);
727 
728  SmallVector<Value, 4> indices(read.getIndices().begin(),
729  read.getIndices().end());
730  auto sequentialType = cast<VectorType>(read.getResult().getType());
731  auto distributedType = cast<VectorType>(distributedVal.getType());
732  AffineMap map = calculateImplicitMap(sequentialType, distributedType);
733  AffineMap indexMap = map.compose(read.getPermutationMap());
734 
735  // Try to delinearize the lane ID to match the rank expected for
736  // distribution.
737  SmallVector<Value> delinearizedIds;
738  if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
739  distributedType.getShape(), warpOp.getWarpSize(),
740  warpOp.getLaneid(), delinearizedIds)) {
741  return rewriter.notifyMatchFailure(
742  read, "cannot delinearize lane ID for distribution");
743  }
744  assert(!delinearizedIds.empty() || map.getNumResults() == 0);
745 
746  // Distribute indices and the mask (if present).
747  OpBuilder::InsertionGuard g(rewriter);
748  SmallVector<Value> additionalResults(indices.begin(), indices.end());
749  SmallVector<Type> additionalResultTypes(indices.size(),
750  rewriter.getIndexType());
751  additionalResults.push_back(read.getPadding());
752  additionalResultTypes.push_back(read.getPadding().getType());
753 
754  bool hasMask = false;
755  if (read.getMask()) {
756  hasMask = true;
757  // TODO: Distribution of masked reads with non-trivial permutation maps
758  // requires the distribution of the mask to elementwise match the
759  // distribution of the permuted vector being read. Currently the details
760  // of which lane is responsible for which element are captured strictly
761  // by shape information on the warp op, and thus require materializing
762  // the permutation in IR.
763  if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
764  return rewriter.notifyMatchFailure(
765  read, "non-trivial permutation maps not supported");
766  VectorType maskType =
767  getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
768  additionalResults.push_back(read.getMask());
769  additionalResultTypes.push_back(maskType);
770  }
771 
772  SmallVector<size_t> newRetIndices;
773  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
774  rewriter, warpOp, additionalResults, additionalResultTypes,
775  newRetIndices);
776  distributedVal = newWarpOp.getResult(operandIndex);
777 
778  // Distributed indices were appended first.
779  SmallVector<Value> newIndices;
780  for (int64_t i = 0, e = indices.size(); i < e; ++i)
781  newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
782 
783  rewriter.setInsertionPointAfter(newWarpOp);
784  for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
785  AffineExpr d0, d1;
786  bindDims(read.getContext(), d0, d1);
787  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
788  if (!indexExpr)
789  continue;
790  unsigned indexPos = indexExpr.getPosition();
791  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
792  int64_t scale = distributedType.getDimSize(vectorPos);
793  newIndices[indexPos] = affine::makeComposedAffineApply(
794  rewriter, read.getLoc(), d0 + scale * d1,
795  {newIndices[indexPos], delinearizedIds[vectorPos]});
796  }
797 
798  // Distributed padding value was appended right after the indices.
799  Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
800  // Distributed mask value was added at the end (if the op has a mask).
801  Value newMask =
802  hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
803  : Value();
804  auto newRead = rewriter.create<vector::TransferReadOp>(
805  read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
806  read.getPermutationMapAttr(), newPadding, newMask,
807  read.getInBoundsAttr());
808 
809  rewriter.replaceAllUsesWith(distributedVal, newRead);
810  return success();
811  }
812 };
813 
814 /// Remove any result that has no use along with the matching yieldOp operand.
815 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
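// Illustrative sketch (added annotation): if only %r#2 is used below,
// ```
// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, vector<1xf32>,
//                                              vector<1xf32>) {
//   ...
//   gpu.yield %a, %a, %b : vector<32xf32>, vector<32xf32>, vector<32xf32>
// }
// ```
// is rewritten to yield only %b, with %r#2 remapped to the single result of
// the new warp op.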
816 struct WarpOpDeadResult : public WarpDistributionPattern {
817  using Base::Base;
818  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
819  PatternRewriter &rewriter) const override {
820  SmallVector<Type> newResultTypes;
821  newResultTypes.reserve(warpOp->getNumResults());
822  SmallVector<Value> newYieldValues;
823  newYieldValues.reserve(warpOp->getNumResults());
824  DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
825  DenseMap<OpResult, int64_t> dedupResultPositionMap;
826  auto yield = cast<gpu::YieldOp>(
827  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
828 
829  // Some values may be yielded multiple times and correspond to multiple
830  // results. Deduplicating occurs by taking each result with its matching
831  // yielded value, and:
832  // 1. recording the unique first position at which the value is yielded.
833  // 2. recording for the result, the first position at which the dedup'ed
834  // value is yielded.
835  // 3. skipping from the new result types / new yielded values any result
836  // that has no use or whose yielded value has already been seen.
837  for (OpResult result : warpOp.getResults()) {
838  Value yieldOperand = yield.getOperand(result.getResultNumber());
839  auto it = dedupYieldOperandPositionMap.insert(
840  std::make_pair(yieldOperand, newResultTypes.size()));
841  dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
842  if (result.use_empty() || !it.second)
843  continue;
844  newResultTypes.push_back(result.getType());
845  newYieldValues.push_back(yieldOperand);
846  }
847  // No modification, exit early.
848  if (yield.getNumOperands() == newYieldValues.size())
849  return failure();
850  // Move the body of the old warpOp to a new warpOp.
851  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
852  rewriter, warpOp, newYieldValues, newResultTypes);
853 
854  // Simplify the new warp op after dropping dead results.
855  newWarpOp.getBody()->walk([&](Operation *op) {
856  if (isOpTriviallyDead(op))
857  rewriter.eraseOp(op);
858  });
859 
860  // Replace results of the old warpOp by the new, deduplicated results.
861  SmallVector<Value> newValues;
862  newValues.reserve(warpOp->getNumResults());
863  for (OpResult result : warpOp.getResults()) {
864  if (result.use_empty())
865  newValues.push_back(Value());
866  else
867  newValues.push_back(
868  newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
869  }
870  rewriter.replaceOp(warpOp, newValues);
871  return success();
872  }
873 };
874 
875 // If a yielded value is defined above the warp op or comes in through a warp
876 // op argument, it can be forwarded directly to the matching result.
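// Illustrative sketch (added annotation):
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid) args(%x : vector<1xf32>)
//     -> (vector<1xf32>) {
// ^bb0(%arg : vector<32xf32>):
//   gpu.yield %arg : vector<32xf32>
// }
// ```
// Uses of %r are replaced directly with the operand %x.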
877 struct WarpOpForwardOperand : public WarpDistributionPattern {
878  using Base::Base;
879  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
880  PatternRewriter &rewriter) const override {
881  SmallVector<Type> resultTypes;
882  SmallVector<Value> yieldValues;
883  auto yield = cast<gpu::YieldOp>(
884  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
885  Value valForwarded;
886  unsigned resultIndex;
887  for (OpOperand &operand : yield->getOpOperands()) {
888  Value result = warpOp.getResult(operand.getOperandNumber());
889  if (result.use_empty())
890  continue;
891 
892  // Assume all the values coming from above are uniform.
893  if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
894  if (result.getType() != operand.get().getType())
895  continue;
896  valForwarded = operand.get();
897  resultIndex = operand.getOperandNumber();
898  break;
899  }
900  auto arg = dyn_cast<BlockArgument>(operand.get());
901  if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
902  continue;
903  Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
904  if (result.getType() != warpOperand.getType())
905  continue;
906  valForwarded = warpOperand;
907  resultIndex = operand.getOperandNumber();
908  break;
909  }
910  if (!valForwarded)
911  return failure();
912  // Notify the rewriter that the warp op is changing (see the comment on
913  // the WarpOpTransferRead pattern).
914  rewriter.startOpModification(warpOp);
915  rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
916  rewriter.finalizeOpModification(warpOp);
917  return success();
918  }
919 };
920 
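/// (Added annotation) Sink out a vector.broadcast that is yielded by the warp
/// op. The broadcast source must be broadcastable to the distributed result
/// type, i.e. it is uniform across lanes, so each lane can rebuild the
/// broadcast itself outside of the region.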
921 struct WarpOpBroadcast : public WarpDistributionPattern {
922  using Base::Base;
923  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
924  PatternRewriter &rewriter) const override {
925  OpOperand *operand =
926  getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
927  if (!operand)
928  return failure();
929  unsigned int operandNumber = operand->getOperandNumber();
930  auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
931  Location loc = broadcastOp.getLoc();
932  auto destVecType =
933  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
934  Value broadcastSrc = broadcastOp.getSource();
935  Type broadcastSrcType = broadcastSrc.getType();
936 
937  // Check that the broadcast actually spans a set of values uniformly across
938  // all threads. In other words, check that each thread can reconstruct
939  // their own broadcast.
940  // For that we simply check that the broadcast we want to build makes sense.
941  if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
942  vector::BroadcastableToResult::Success)
943  return failure();
944  SmallVector<size_t> newRetIndices;
945  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
946  rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
947  rewriter.setInsertionPointAfter(newWarpOp);
948  Value broadcasted = rewriter.create<vector::BroadcastOp>(
949  loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
950  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
951  broadcasted);
952  return success();
953  }
954 };
955 
956 /// Pattern to move a shape cast out of the warp op. A shape cast is mostly a
957 /// no-op for warp distribution; only the distributed shape needs adjusting.
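/// Illustrative sketch (added annotation; warp size 32): a yielded
///   %c = vector.shape_cast %v : vector<32x1xf32> to vector<32xf32>
/// distributed as vector<1xf32> becomes: the warp op yields %v with the
/// distributed type vector<1x1xf32>, and each lane shape-casts that to
/// vector<1xf32> outside of the region.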
958 struct WarpOpShapeCast : public WarpDistributionPattern {
959  using Base::Base;
960  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
961  PatternRewriter &rewriter) const override {
962  OpOperand *operand =
963  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
964  if (!operand)
965  return failure();
966 
967  auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
968 
969  unsigned int operandNumber = operand->getOperandNumber();
970  auto castDistributedType =
971  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
972  VectorType castOriginalType = oldCastOp.getSourceVectorType();
973  VectorType castResultType = castDistributedType;
974 
975  // We expect the distributed type to have a smaller rank than the original
976  // type. Prepend with size-one dimensions to make them the same.
977  unsigned castDistributedRank = castDistributedType.getRank();
978  unsigned castOriginalRank = castOriginalType.getRank();
979  if (castDistributedRank < castOriginalRank) {
980  SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
981  llvm::append_range(shape, castDistributedType.getShape());
982  castDistributedType =
983  VectorType::get(shape, castDistributedType.getElementType());
984  }
985 
986  SmallVector<size_t> newRetIndices;
987  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
988  rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
989  newRetIndices);
990  rewriter.setInsertionPointAfter(newWarpOp);
991  Value newCast = rewriter.create<vector::ShapeCastOp>(
992  oldCastOp.getLoc(), castResultType,
993  newWarpOp->getResult(newRetIndices[0]));
994  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
995  return success();
996  }
997 };
998 
999 /// Sink out vector.create_mask op feeding into a warp op yield.
1000 /// ```
1001 /// %0 = ...
1002 /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
1003 /// ...
1004 /// %mask = vector.create_mask %0 : vector<32xi1>
1005 /// gpu.yield %mask : vector<32xi1>
1006 /// }
1007 /// ```
1008 /// To
1009 /// ```
1010 /// %0 = ...
1011 /// gpu.warp_execute_on_lane_0(%arg0) {
1012 /// ...
1013 /// }
1014 /// %cmp = arith.cmpi ult, %laneid, %0
1015 /// %ub = arith.select %cmp, %c0, %c1
1017 /// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
1017 struct WarpOpCreateMask : public WarpDistributionPattern {
1018  using Base::Base;
1019  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1020  PatternRewriter &rewriter) const override {
1021  OpOperand *yieldOperand =
1022  getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1023  if (!yieldOperand)
1024  return failure();
1025 
1026  auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1027 
1028  // Early exit if any values needed for calculating the new mask indices
1029  // are defined inside the warp op.
1030  if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1031  return warpOp.isDefinedOutsideOfRegion(value);
1032  }))
1033  return failure();
1034 
1035  Location loc = mask.getLoc();
1036  unsigned operandIndex = yieldOperand->getOperandNumber();
1037 
1038  auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1039  VectorType seqType = mask.getVectorType();
1040  ArrayRef<int64_t> seqShape = seqType.getShape();
1041  ArrayRef<int64_t> distShape = distType.getShape();
1042 
1043  rewriter.setInsertionPointAfter(warpOp);
1044 
1045  // Delinearize the lane ID for constructing the distributed mask sizes.
1046  SmallVector<Value> delinearizedIds;
1047  if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1048  warpOp.getWarpSize(), warpOp.getLaneid(),
1049  delinearizedIds))
1050  return rewriter.notifyMatchFailure(
1051  mask, "cannot delinearize lane ID for distribution");
1052  assert(!delinearizedIds.empty());
1053 
1054  // Notify the rewriter that the warp op is changing (see the comment on
1055  // the WarpOpTransferRead pattern).
1056  rewriter.startOpModification(warpOp);
1057 
1058  AffineExpr s0, s1;
1059  bindSymbols(rewriter.getContext(), s0, s1);
1060  SmallVector<Value> newOperands;
1061  for (int i = 0, e = distShape.size(); i < e; ++i) {
1062  // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1063  // find the distance from the largest mask index owned by this lane to the
1064  // original mask size. `vector.create_mask` implicitly clamps mask
1065  // operands to the range [0, mask_vector_size[i]], or in other words, the
1066  // mask sizes are always in the range [0, mask_vector_size[i]].
1067  Value maskDimIdx = affine::makeComposedAffineApply(
1068  rewriter, loc, s1 - s0 * distShape[i],
1069  {delinearizedIds[i], mask.getOperand(i)});
1070  newOperands.push_back(maskDimIdx);
1071  }
1072 
1073  auto newMask =
1074  rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
1075  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1076  rewriter.finalizeOpModification(warpOp);
1077  return success();
1078  }
1079 };
1080 
1081 /// Pattern to move out vector.extract from 2-D or higher dimensional source
1082 /// vectors, either by propagating the extract outside of the region when there
/// is no distribution (broadcast), or by extracting from the distributed source.
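/// Illustrative sketch (added annotation; warp size 4): a yielded
///   %e = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>
/// distributed as vector<1xf32> becomes: the warp op yields %src with the
/// distributed type vector<2x1xf32>, and each lane extracts position [1] from
/// it outside of the region to obtain its vector<1xf32>.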
1083 struct WarpOpExtract : public WarpDistributionPattern {
1084  using Base::Base;
1085  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1086  PatternRewriter &rewriter) const override {
1087  OpOperand *operand =
1088  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1089  if (!operand)
1090  return failure();
1091  unsigned int operandNumber = operand->getOperandNumber();
1092  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1093  VectorType extractSrcType = extractOp.getSourceVectorType();
1094  Location loc = extractOp.getLoc();
1095 
1096  // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1097  if (extractSrcType.getRank() <= 1) {
1098  return failure();
1099  }
1100 
1101  // All following cases are 2d or higher dimensional source vectors.
1102 
1103  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1104  // There is no distribution, this is a broadcast. Simply move the extract
1105  // out of the warp op.
1106  // TODO: This could be optimized. E.g., in case of a scalar result, let
1107  // one lane extract and shuffle the result to all other lanes (same as
1108  // the 1d case).
1109  SmallVector<size_t> newRetIndices;
1110  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1111  rewriter, warpOp, {extractOp.getVector()},
1112  {extractOp.getSourceVectorType()}, newRetIndices);
1113  rewriter.setInsertionPointAfter(newWarpOp);
1114  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1115  // Extract from distributed vector.
1116  Value newExtract = rewriter.create<vector::ExtractOp>(
1117  loc, distributedVec, extractOp.getMixedPosition());
1118  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1119  newExtract);
1120  return success();
1121  }
1122 
1123  // Find the distributed dimension. There should be exactly one.
1124  auto distributedType =
1125  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1126  auto yieldedType = cast<VectorType>(operand->get().getType());
1127  int64_t distributedDim = -1;
1128  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1129  if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1130  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1131  // support distributing multiple dimensions in the future.
1132  assert(distributedDim == -1 && "found multiple distributed dims");
1133  distributedDim = i;
1134  }
1135  }
1136  assert(distributedDim != -1 && "could not find distributed dimension");
1137  (void)distributedDim;
1138 
1139  // Yield source vector from warp op.
1140  SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1141  for (int i = 0; i < distributedType.getRank(); ++i)
1142  newDistributedShape[i + extractOp.getNumIndices()] =
1143  distributedType.getDimSize(i);
1144  auto newDistributedType =
1145  VectorType::get(newDistributedShape, distributedType.getElementType());
1146  SmallVector<size_t> newRetIndices;
1147  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1148  rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1149  newRetIndices);
1150  rewriter.setInsertionPointAfter(newWarpOp);
1151  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1152  // Extract from distributed vector.
1153  Value newExtract = rewriter.create<vector::ExtractOp>(
1154  loc, distributedVec, extractOp.getMixedPosition());
1155  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1156  newExtract);
1157  return success();
1158  }
1159 };
1160 
1161 /// Pattern to move out vector.extract with a scalar result.
1162 /// Only supports 1-D and 0-D sources for now.
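/// Illustrative sketch (added annotation; warp size 32): extracting position
/// %pos from a yielded vector<64xf32> (2 elements per lane) becomes: each lane
/// extracts element %pos % 2 from its distributed vector<2xf32>, and the value
/// held by the lane owning %pos is shuffled to all lanes via
/// `warpShuffleFromIdxFn`.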
1163 struct WarpOpExtractScalar : public WarpDistributionPattern {
1164  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1165  PatternBenefit b = 1)
1166  : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1167  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1168  PatternRewriter &rewriter) const override {
1169  OpOperand *operand =
1170  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1171  if (!operand)
1172  return failure();
1173  unsigned int operandNumber = operand->getOperandNumber();
1174  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1175  VectorType extractSrcType = extractOp.getSourceVectorType();
1176  // Only supports 1-D or 0-D sources for now.
1177  if (extractSrcType.getRank() > 1) {
1178  return rewriter.notifyMatchFailure(
1179  extractOp, "only 0-D or 1-D source supported for now");
1180  }
1181  // TODO: Supported shuffle types should be parameterizable, similar to
1182  // `WarpShuffleFromIdxFn`.
1183  if (!extractSrcType.getElementType().isF32() &&
1184  !extractSrcType.getElementType().isInteger(32))
1185  return rewriter.notifyMatchFailure(
1186  extractOp, "only f32/i32 element types are supported");
1187  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1188  Type elType = extractSrcType.getElementType();
1189  VectorType distributedVecType;
1190  if (!is0dOrVec1Extract) {
1191  assert(extractSrcType.getRank() == 1 &&
1192  "expected that extract src rank is 0 or 1");
1193  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1194  return failure();
1195  int64_t elementsPerLane =
1196  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1197  distributedVecType = VectorType::get({elementsPerLane}, elType);
1198  } else {
1199  distributedVecType = extractSrcType;
1200  }
1201  // Yield source vector and position (if present) from warp op.
1202  SmallVector<Value> additionalResults{extractOp.getVector()};
1203  SmallVector<Type> additionalResultTypes{distributedVecType};
1204  additionalResults.append(
1205  SmallVector<Value>(extractOp.getDynamicPosition()));
1206  additionalResultTypes.append(
1207  SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1208 
1209  Location loc = extractOp.getLoc();
1210  SmallVector<size_t> newRetIndices;
1211  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1212  rewriter, warpOp, additionalResults, additionalResultTypes,
1213  newRetIndices);
1214  rewriter.setInsertionPointAfter(newWarpOp);
1215  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1216 
1217  // 0d extract: The new warp op broadcasts the source vector to all lanes.
1218  // All lanes extract the scalar.
1219  if (is0dOrVec1Extract) {
1220  Value newExtract;
1221  SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1222  newExtract =
1223  rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
1224  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1225  newExtract);
1226  return success();
1227  }
1228 
1229  int64_t staticPos = extractOp.getStaticPosition()[0];
1230  OpFoldResult pos = ShapedType::isDynamic(staticPos)
1231  ? (newWarpOp->getResult(newRetIndices[1]))
1232  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1233  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1234  // the value to all other lanes.
1235  int64_t elementsPerLane = distributedVecType.getShape()[0];
1236  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1237  // tid of extracting thread: pos / elementsPerLane
1238  Value broadcastFromTid = affine::makeComposedAffineApply(
1239  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1240  // Extract at position: pos % elementsPerLane
1241  Value newPos =
1242  elementsPerLane == 1
1243  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1244  : affine::makeComposedAffineApply(rewriter, loc,
1245  sym0 % elementsPerLane, pos);
1246  Value extracted =
1247  rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
1248 
1249  // Shuffle the extracted value to all lanes.
1250  Value shuffled = warpShuffleFromIdxFn(
1251  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1252  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1253  return success();
1254  }
1255 
1256 private:
1257  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1258 };
1259 
1260 /// Pattern to convert vector.extractelement to vector.extract.
1261 struct WarpOpExtractElement : public WarpDistributionPattern {
1262  using Base::Base;
1263  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1264  PatternRewriter &rewriter) const override {
1265  OpOperand *operand =
1266  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1267  if (!operand)
1268  return failure();
1269  auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1270  SmallVector<OpFoldResult> indices;
1271  if (auto pos = extractOp.getPosition()) {
1272  indices.push_back(pos);
1273  }
1274  rewriter.setInsertionPoint(extractOp);
1275  rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1276  extractOp, extractOp.getVector(), indices);
1277  return success();
1278  }
1279 };
1280 
1281 /// Pattern to move out vector.insert with a scalar input.
1282 /// Only supports 1-D and 0-D destinations for now.
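/// Illustrative sketch (added annotation; warp size 32): inserting a scalar at
/// position %pos into a yielded vector<64xf32> (2 elements per lane) becomes:
/// an scf.if on `laneid == insertingLane` in which the owning lane inserts the
/// scalar at %pos % 2 into its distributed vector<2xf32>; all other lanes
/// yield their distributed vector unchanged.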
1283 struct WarpOpInsertScalar : public WarpDistributionPattern {
1284  using Base::Base;
1285  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1286  PatternRewriter &rewriter) const override {
1287  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1288  if (!operand)
1289  return failure();
1290  unsigned int operandNumber = operand->getOperandNumber();
1291  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1292  VectorType vecType = insertOp.getDestVectorType();
1293  VectorType distrType =
1294  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1295 
1296  // Only supports 1-D or 0-D destinations for now.
1297  if (vecType.getRank() > 1) {
1298  return rewriter.notifyMatchFailure(
1299  insertOp, "only 0-D or 1-D source supported for now");
1300  }
1301 
1302  // Yield destination vector, source scalar and position from warp op.
1303  SmallVector<Value> additionalResults{insertOp.getDest(),
1304  insertOp.getValueToStore()};
1305  SmallVector<Type> additionalResultTypes{
1306  distrType, insertOp.getValueToStore().getType()};
1307  additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1308  additionalResultTypes.append(
1309  SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1310 
1311  Location loc = insertOp.getLoc();
1312  SmallVector<size_t> newRetIndices;
1313  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1314  rewriter, warpOp, additionalResults, additionalResultTypes,
1315  newRetIndices);
1316  rewriter.setInsertionPointAfter(newWarpOp);
1317  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1318  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1319  rewriter.setInsertionPointAfter(newWarpOp);
1320 
1321  OpFoldResult pos;
1322  if (vecType.getRank() != 0) {
1323  int64_t staticPos = insertOp.getStaticPosition()[0];
1324  pos = ShapedType::isDynamic(staticPos)
1325  ? (newWarpOp->getResult(newRetIndices[2]))
1326  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1327  }
1328 
1329  // This condition is always true for 0-d vectors.
1330  if (vecType == distrType) {
1331  Value newInsert;
1332  SmallVector<OpFoldResult> indices;
1333  if (pos) {
1334  indices.push_back(pos);
1335  }
1336  newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1337  distributedVec, indices);
1338  // Broadcast: Simply move the vector.insert op out.
1339  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1340  newInsert);
1341  return success();
1342  }
1343 
1344  // This is a distribution. Only one lane should insert.
1345  int64_t elementsPerLane = distrType.getShape()[0];
1346  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1347  // tid of inserting lane: pos / elementsPerLane
1348  Value insertingLane = affine::makeComposedAffineApply(
1349  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1350  // Insert position: pos % elementsPerLane
1351  OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1352  rewriter, loc, sym0 % elementsPerLane, pos);
1353  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1354  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1355  Value newResult =
1356  rewriter
1357  .create<scf::IfOp>(
1358  loc, isInsertingLane,
1359  /*thenBuilder=*/
1360  [&](OpBuilder &builder, Location loc) {
1361  Value newInsert = builder.create<vector::InsertOp>(
1362  loc, newSource, distributedVec, newPos);
1363  builder.create<scf::YieldOp>(loc, newInsert);
1364  },
1365  /*elseBuilder=*/
1366  [&](OpBuilder &builder, Location loc) {
1367  builder.create<scf::YieldOp>(loc, distributedVec);
1368  })
1369  .getResult(0);
1370  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1371  return success();
1372  }
1373 };
1374 
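/// Pattern to move out vector.insert with a vector source and an n-D (n >= 2)
/// destination. Depending on which dimension is distributed, either every
/// lane inserts a slice of the distributed source, or a single lane inserts
/// the whole source vector (see the per-case comments below). The following
/// sketch is illustrative only; shapes and SSA names are assumptions:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<128x3xf32>) {
///   %0 = "some_def"() : () -> (vector<96xf32>)
///   %1 = "some_def"() : () -> (vector<128x96xf32>)
///   %2 = vector.insert %0, %1 [2] : vector<96xf32> into vector<128x96xf32>
///   gpu.yield %2 : vector<128x96xf32>
/// }
/// ```
/// Here the trailing dimension (96) is distributed, so each lane inserts a
/// vector<3xf32> slice of the source into its vector<128x3xf32> piece.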
1375 struct WarpOpInsert : public WarpDistributionPattern {
1376  using Base::Base;
1377  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1378  PatternRewriter &rewriter) const override {
1379  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1380  if (!operand)
1381  return failure();
1382  unsigned int operandNumber = operand->getOperandNumber();
1383  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1384  Location loc = insertOp.getLoc();
1385 
1386  // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1387  if (insertOp.getDestVectorType().getRank() <= 1) {
1388  return failure();
1389  }
1390 
1391  // All following cases are 2-d or higher dimensional destination vectors.
1392 
1393  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1394  // There is no distribution, this is a broadcast. Simply move the insert
1395  // out of the warp op.
1396  SmallVector<size_t> newRetIndices;
1397  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1398  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1399  {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1400  newRetIndices);
1401  rewriter.setInsertionPointAfter(newWarpOp);
1402  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1403  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1404  Value newResult = rewriter.create<vector::InsertOp>(
1405  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1406  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1407  newResult);
1408  return success();
1409  }
1410 
1411  // Find the distributed dimension. There should be exactly one.
1412  auto distrDestType =
1413  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1414  auto yieldedType = cast<VectorType>(operand->get().getType());
1415  int64_t distrDestDim = -1;
1416  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1417  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1418  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1419  // support distributing multiple dimensions in the future.
1420  assert(distrDestDim == -1 && "found multiple distributed dims");
1421  distrDestDim = i;
1422  }
1423  }
1424  assert(distrDestDim != -1 && "could not find distributed dimension");
1425 
1426  // Compute the distributed source vector type.
1427  VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1428  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1429  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1430  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1431  // insert a smaller vector<3xf32>.
1432  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1433  // case, one lane will insert the source vector<96xf32>. The other
1434  // lanes will not do anything.
1435  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1436  if (distrSrcDim >= 0)
1437  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1438  auto distrSrcType =
1439  VectorType::get(distrSrcShape, distrDestType.getElementType());
1440 
1441  // Yield source and dest vectors from warp op.
1442  SmallVector<size_t> newRetIndices;
1443  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1444  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1445  {distrSrcType, distrDestType}, newRetIndices);
1446  rewriter.setInsertionPointAfter(newWarpOp);
1447  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1448  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1449 
1450  // Insert into the distributed vector.
1451  Value newResult;
1452  if (distrSrcDim >= 0) {
1453  // Every lane inserts a small piece.
1454  newResult = rewriter.create<vector::InsertOp>(
1455  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1456  } else {
1457  // One lane inserts the entire source vector.
1458  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1459  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1460  SmallVector<int64_t> newPos = getAsIntegers(pos);
1461  // tid of inserting lane: pos / elementsPerLane
1462  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1463  loc, newPos[distrDestDim] / elementsPerLane);
1464  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1465  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1466  // Insert position: pos % elementsPerLane
1467  newPos[distrDestDim] %= elementsPerLane;
1468  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1469  Value newInsert = builder.create<vector::InsertOp>(
1470  loc, distributedSrc, distributedDest, newPos);
1471  builder.create<scf::YieldOp>(loc, newInsert);
1472  };
1473  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1474  builder.create<scf::YieldOp>(loc, distributedDest);
1475  };
1476  newResult = rewriter
1477  .create<scf::IfOp>(loc, isInsertingLane,
1478  /*thenBuilder=*/insertingBuilder,
1479  /*elseBuilder=*/nonInsertingBuilder)
1480  .getResult(0);
1481  }
1482 
1483  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1484  return success();
1485  }
1486 };
1487 
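/// Pattern to convert vector.insertelement to vector.insert, mirroring the
/// WarpOpExtractElement pattern above.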
1488 struct WarpOpInsertElement : public WarpDistributionPattern {
1489  using Base::Base;
1490  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1491  PatternRewriter &rewriter) const override {
1492  OpOperand *operand =
1493  getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1494  if (!operand)
1495  return failure();
1496  auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1497  SmallVector<OpFoldResult> indices;
1498  if (auto pos = insertOp.getPosition()) {
1499  indices.push_back(pos);
1500  }
1501  rewriter.setInsertionPoint(insertOp);
1502  rewriter.replaceOpWithNewOp<vector::InsertOp>(
1503  insertOp, insertOp.getSource(), insertOp.getDest(), indices);
1504  return success();
1505  }
1506 };
1507 
1508 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1509 /// the scf.ForOp is the last operation in the region so that it doesn't
1510 /// change the order of execution. This creates a new scf.for region after the
1511 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1512 /// WarpExecuteOnLane0Op region. Example:
1513 /// ```
1514 /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1515 /// ...
1516 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1517 /// -> (vector<128xf32>) {
1518 /// ...
1519 /// scf.yield %r : vector<128xf32>
1520 /// }
1521 /// gpu.yield %v1 : vector<128xf32>
1522 /// }
1523 /// ```
1524 /// To:
/// ```
1525 /// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1526 /// ...
1527 /// gpu.yield %v : vector<128xf32>
1528 /// }
1529 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1530 /// -> (vector<4xf32>) {
1531 /// %iw = gpu.warp_execute_on_lane_0(%laneid)
1532 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1533 /// ^bb0(%arg: vector<128xf32>):
1534 /// ...
1535 /// gpu.yield %ir : vector<128xf32>
1536 /// }
1537 /// scf.yield %iw : vector<4xf32>
1538 /// }
1539 /// ```
1540 struct WarpOpScfForOp : public WarpDistributionPattern {
1541 
1542  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1543  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1544  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1545  PatternRewriter &rewriter) const override {
1546  auto yield = cast<gpu::YieldOp>(
1547  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1548  // Only pick up forOp if it is the last op in the region.
1549  Operation *lastNode = yield->getPrevNode();
1550  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1551  if (!forOp)
1552  return failure();
1553  // Collect values that are defined in the warp op region but used inside
1554  // the forOp. Those values need to be returned by the original warpOp and
1555  // passed to the new op.
1556  llvm::SmallSetVector<Value, 32> escapingValues;
1557  SmallVector<Type> inputTypes;
1558  SmallVector<Type> distTypes;
1559  mlir::visitUsedValuesDefinedAbove(
1560  forOp.getBodyRegion(), [&](OpOperand *operand) {
1561  Operation *parent = operand->get().getParentRegion()->getParentOp();
1562  if (warpOp->isAncestor(parent)) {
1563  if (!escapingValues.insert(operand->get()))
1564  return;
1565  Type distType = operand->get().getType();
1566  if (auto vecType = dyn_cast<VectorType>(distType)) {
1567  AffineMap map = distributionMapFn(operand->get());
1568  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1569  }
1570  inputTypes.push_back(operand->get().getType());
1571  distTypes.push_back(distType);
1572  }
1573  });
1574 
1575  if (llvm::is_contained(distTypes, Type{}))
1576  return failure();
1577 
1578  SmallVector<size_t> newRetIndices;
1579  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1580  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1581  newRetIndices);
1582  yield = cast<gpu::YieldOp>(
1583  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1584 
1585  SmallVector<Value> newOperands;
1586  SmallVector<unsigned> resultIdx;
1587  // Collect all the outputs coming from the forOp.
1588  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1589  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1590  continue;
1591  auto forResult = cast<OpResult>(yieldOperand.get());
1592  newOperands.push_back(
1593  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1594  yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1595  resultIdx.push_back(yieldOperand.getOperandNumber());
1596  }
1597 
1598  OpBuilder::InsertionGuard g(rewriter);
1599  rewriter.setInsertionPointAfter(newWarpOp);
1600 
1601  // Create a new for op outside the region with a WarpExecuteOnLane0Op
1602  // region inside.
1603  auto newForOp = rewriter.create<scf::ForOp>(
1604  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1605  forOp.getStep(), newOperands);
1606  rewriter.setInsertionPointToStart(newForOp.getBody());
1607 
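  // Pass the distributed loop iter_args plus the (distributed) escaping
  // values into the inner warp op, and remember which inner block argument
  // each escaping value maps to so that its uses can be remapped below.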
1608  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1609  newForOp.getRegionIterArgs().end());
1610  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1611  forOp.getResultTypes().end());
1612  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1613  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1614  warpInput.push_back(newWarpOp.getResult(retIdx));
1615  argIndexMapping[escapingValues[i]] = warpInputType.size();
1616  warpInputType.push_back(inputTypes[i]);
1617  }
1618  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1619  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1620  newWarpOp.getWarpSize(), warpInput, warpInputType);
1621 
1622  SmallVector<Value> argMapping;
1623  argMapping.push_back(newForOp.getInductionVar());
1624  for (Value args : innerWarp.getBody()->getArguments()) {
1625  argMapping.push_back(args);
1626  }
1627  argMapping.resize(forOp.getBody()->getNumArguments());
1628  SmallVector<Value> yieldOperands;
1629  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1630  yieldOperands.push_back(operand);
1631  rewriter.eraseOp(forOp.getBody()->getTerminator());
1632  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1633  rewriter.setInsertionPointToEnd(innerWarp.getBody());
1634  rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
1635  rewriter.setInsertionPointAfter(innerWarp);
1636  if (!innerWarp.getResults().empty())
1637  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1638  rewriter.eraseOp(forOp);
1639  // Replace the warpOp result coming from the original ForOp.
1640  for (const auto &res : llvm::enumerate(resultIdx)) {
1641  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1642  newForOp.getResult(res.index()));
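      // The first three scf.for operands are the lower bound, upper bound and
      // step, so the iter_args start at operand index 3.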
1643  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1644  }
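  // Remap uses of the escaping values inside the new loop body to the
  // corresponding block arguments of the inner warp op.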
1645  newForOp.walk([&](Operation *op) {
1646  for (OpOperand &operand : op->getOpOperands()) {
1647  auto it = argIndexMapping.find(operand.get());
1648  if (it == argIndexMapping.end())
1649  continue;
1650  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1651  }
1652  });
1653 
1654  // Finally, hoist out any now uniform code from the inner warp op.
1655  mlir::vector::moveScalarUniformCode(innerWarp);
1656  return success();
1657  }
1658 
1659 private:
1660  DistributionMapFn distributionMapFn;
1661 };
1662 
1663 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1664 /// The vector is reduced in parallel. Currently limited to vector sizes
1665 /// that are a multiple of the warp size. E.g.:
1666 /// ```
1667 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1668 /// %0 = "some_def"() : () -> (vector<32xf32>)
1669 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1670 /// gpu.yield %1 : f32
1671 /// }
1672 /// ```
1673 /// is lowered to:
1674 /// ```
1675 /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1676 /// %1 = "some_def"() : () -> (vector<32xf32>)
1677 /// gpu.yield %1 : vector<32xf32>
1678 /// }
1679 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1680 /// %r = ("warp.reduction %a")
1681 /// ```
1682 struct WarpOpReduction : public WarpDistributionPattern {
1683  WarpOpReduction(MLIRContext *context,
1684  DistributedReductionFn distributedReductionFn,
1685  PatternBenefit benefit = 1)
1686  : WarpDistributionPattern(context, benefit),
1687  distributedReductionFn(std::move(distributedReductionFn)) {}
1688 
1689  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1690  PatternRewriter &rewriter) const override {
1691  OpOperand *yieldOperand =
1692  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1693  if (!yieldOperand)
1694  return failure();
1695 
1696  auto reductionOp =
1697  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1698  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1699  // Only rank 1 vectors supported.
1700  if (vectorType.getRank() != 1)
1701  return rewriter.notifyMatchFailure(
1702  warpOp, "Only rank 1 reductions can be distributed.");
1703  // Only vectors whose size is a multiple of the warp size are supported.
1704  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1705  return rewriter.notifyMatchFailure(
1706  warpOp, "Reduction vector dimension must be a multiple of the warp size.");
1707  if (!reductionOp.getType().isIntOrFloat())
1708  return rewriter.notifyMatchFailure(
1709  warpOp, "Reduction distribution currently only supports floats and "
1710  "integer types.");
1711 
1712  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1713  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1714  unsigned operandIndex = yieldOperand->getOperandNumber();
1715  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1716  SmallVector<Type> retTypes = {
1717  VectorType::get({numElements}, reductionOp.getType())};
1718  if (reductionOp.getAcc()) {
1719  yieldValues.push_back(reductionOp.getAcc());
1720  retTypes.push_back(reductionOp.getAcc().getType());
1721  }
1722  SmallVector<size_t> newRetIndices;
1723  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1724  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1725  rewriter.setInsertionPointAfter(newWarpOp);
1726 
1727  // Obtain data to reduce for a single lane.
1728  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1729  // Distribute and reduce across threads.
1730  Value fullReduce =
1731  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1732  reductionOp.getKind(), newWarpOp.getWarpSize());
1733  if (reductionOp.getAcc()) {
1734  fullReduce = vector::makeArithReduction(
1735  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1736  newWarpOp.getResult(newRetIndices[1]));
1737  }
1738  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1739  return success();
1740  }
1741 
1742 private:
1743  DistributedReductionFn distributedReductionFn;
1744 };
1745 
1746 } // namespace
1747 
1748 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1749  RewritePatternSet &patterns,
1750  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
1751  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1752 }
1753 
1754 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1755  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1756  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
1757  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1758  maxNumElementsToExtract, benefit);
1759 }
1760 
1761 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1762  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1763  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1764  PatternBenefit readBenefit) {
1765  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1766  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1767  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1768  WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1769  WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1770  patterns.getContext(), benefit);
1771  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
1772  benefit);
1773  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1774  benefit);
1775 }
1776 
1777 void mlir::vector::populateDistributeReduction(
1778  RewritePatternSet &patterns,
1779  const DistributedReductionFn &distributedReductionFn,
1780  PatternBenefit benefit) {
1781  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1782  benefit);
1783 }
1784 
1785 /// Helper to know if an op can be hoisted out of the region.
1786 static bool canBeHoisted(Operation *op,
1787  function_ref<bool(Value)> definedOutside) {
1788  return llvm::all_of(op->getOperands(), definedOutside) &&
1789  isMemoryEffectFree(op) && op->getNumRegions() == 0;
1790 }
1791 
1792 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1793  Block *body = warpOp.getBody();
1794 
1795  // Keep track of the ops we want to hoist.
1796  llvm::SmallSetVector<Operation *, 8> opsToMove;
1797 
1798  // Helper to check if a value is or will be defined outside of the region.
1799  auto isDefinedOutsideOfBody = [&](Value value) {
1800  auto *definingOp = value.getDefiningOp();
1801  return (definingOp && opsToMove.count(definingOp)) ||
1802  warpOp.isDefinedOutsideOfRegion(value);
1803  };
1804 
1805  // Do not use walk here, as we do not want to go into nested regions and hoist
1806  // operations from there.
1807  for (auto &op : body->without_terminator()) {
1808  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1809  return isa<VectorType>(result.getType());
1810  });
1811  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1812  opsToMove.insert(&op);
1813  }
1814 
1815  // Move all the ops marked as uniform outside of the region.
1816  for (Operation *op : opsToMove)
1817  op->moveBefore(warpOp);
1818 }