1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>
23 
24 using namespace mlir;
25 using namespace mlir::vector;
26 using namespace mlir::gpu;
27 
28 /// Currently the distribution map is implicit based on the vector shape. In the
29 /// future it will be part of the op.
30 /// Example:
31 /// ```
32 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
33 /// ...
34 /// gpu.yield %3 : vector<32x16x64xf32>
35 /// }
36 /// ```
37 /// Would have an implicit map of:
38 /// `(d0, d1, d2) -> (d0, d2)`
39 static AffineMap calculateImplicitMap(VectorType sequentialType,
40  VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type differ from the dimensions
  // of the distributed type to identify the distributed dimensions. Then
  // associate each distributed dimension with an ID in order.
46  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
47  if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
48  perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
49  }
50  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
51  distributedType.getContext());
52  return map;
53 }
54 
55 namespace {
56 
/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when lowering a `WarpExecuteOnLane0Op` to an scf.if (see
/// `WarpOpToScfIfPattern` below).
60 ///
61 /// The vector distribution dimension is inferred from the vector types.
62 struct DistributedLoadStoreHelper {
63  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
64  Value laneId, Value zero)
65  : sequentialVal(sequentialVal), distributedVal(distributedVal),
66  laneId(laneId), zero(zero) {
67  sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
68  distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
69  if (sequentialVectorType && distributedVectorType)
70  distributionMap =
71  calculateImplicitMap(sequentialVectorType, distributedVectorType);
72  }
73 
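  /// Build the distributed offset for dimension `index`:
  ///   laneId * distributedVectorType.getDimSize(index).
  /// For example, if each lane owns a vector<2xf32> slice of a vector<64xf32>
  /// value, lane 5 gets the offset 5 * 2 = 10.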
74  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
79  }
80 
  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
88  Operation *buildStore(RewriterBase &b, Location loc, Value val,
89  Value buffer) {
90  assert((val == distributedVal || val == sequentialVal) &&
91  "Must store either the preregistered distributed or the "
92  "preregistered sequential value.");
93  // Scalar case can directly use memref.store.
94  if (!isa<VectorType>(val.getType()))
95  return b.create<memref::StoreOp>(loc, val, buffer, zero);
96 
    // Vector case must use vector::TransferWriteOp, which will later lower to
    // vector.store or memref.store depending on further lowerings.
99  int64_t rank = sequentialVectorType.getRank();
100  SmallVector<Value> indices(rank, zero);
101  if (val == distributedVal) {
102  for (auto dimExpr : distributionMap.getResults()) {
103  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
104  indices[index] = buildDistributedOffset(b, loc, index);
105  }
106  }
107  SmallVector<bool> inBounds(indices.size(), true);
108  return b.create<vector::TransferWriteOp>(
109  loc, val, buffer, indices,
110  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
111  }
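  // Illustrative sketch of the IR produced when storing a distributed
  // vector<1xf32> (not verbatim output):
  //   %off = affine.apply affine_map<()[s0] -> (s0)>()[%laneId]
  //   vector.transfer_write %val, %buffer[%off] {in_bounds = [true]}
  //       : vector<1xf32>, memref<32xf32>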
112 
  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When the requested `type` is the sequential type, the load is not
  /// distributed, to account for the broadcast semantics of the
  /// `gpu.warp_execute_on_lane_0` op.
  ///
124  /// Example:
125  ///
126  /// ```
127  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
128  /// gpu.yield %cst : f32
129  /// }
130  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
131  /// ```
  /// This behavior is described in more detail in the documentation of the op.
133  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
134 
    // Scalar case can directly use memref.load.
136  if (!isa<VectorType>(type))
137  return b.create<memref::LoadOp>(loc, buffer, zero);
138 
    // Other cases must be vectors at the moment.
    // Vector case must use vector::TransferReadOp, which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
145  SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
146  if (type == distributedVectorType) {
147  for (auto dimExpr : distributionMap.getResults()) {
148  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
149  indices[index] = buildDistributedOffset(b, loc, index);
150  }
151  }
152  SmallVector<bool> inBounds(indices.size(), true);
153  return b.create<vector::TransferReadOp>(
154  loc, cast<VectorType>(type), buffer, indices,
155  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
156  }
157 
158  Value sequentialVal, distributedVal, laneId, zero;
159  VectorType sequentialVectorType, distributedVectorType;
160  AffineMap distributionMap;
161 };
162 
163 } // namespace
164 
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
171  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
172  op->getAttrs());
173  return rewriter.create(res);
174 }
175 
176 namespace {
177 
178 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
179 /// thread `laneId` executes the entirety of the computation.
180 ///
181 /// After the transformation:
182 /// - the IR within the scf.if op can be thought of as executing sequentially
183 /// (from the point of view of threads along `laneId`).
184 /// - the IR outside of the scf.if op can be thought of as executing in
185 /// parallel (from the point of view of threads along `laneId`).
186 ///
187 /// Values that need to transit through the parallel / sequential and the
188 /// sequential / parallel boundaries do so via reads and writes to a temporary
189 /// memory location.
190 ///
191 /// The transformation proceeds in multiple steps:
192 /// 1. Create the scf.if op.
193 /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
194 /// within the scf.if to transit the values captured from above.
195 /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
196 /// consistent within the scf.if.
197 /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
198 /// 5. Insert appropriate writes within scf.if and reads after the scf.if to
199 /// transit the values returned by the op.
200 /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
201 /// consistent after the scf.if.
202 /// 7. Perform late cleanups.
203 ///
204 /// All this assumes the vector distribution occurs along the most minor
205 /// distributed vector dimension.
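///
/// Illustrative sketch of the result (shapes and buffer handling simplified,
/// not verbatim output):
/// ```
/// vector.transfer_write %distributed_arg, %arg_buffer[%laneid_offset]
/// gpu.barrier
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   %sequential_arg = vector.transfer_read %arg_buffer[%c0]
///   ...  // original body
///   vector.transfer_write %sequential_result, %result_buffer[%c0]
/// }
/// gpu.barrier
/// %distributed_result = vector.transfer_read %result_buffer[%laneid_offset]
/// ```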
206 struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}
211 
212  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
213  PatternRewriter &rewriter) const override {
214  assert(warpOp.getBodyRegion().hasOneBlock() &&
215  "expected WarpOp with single block");
216  Block *warpOpBody = &warpOp.getBodyRegion().front();
217  Location loc = warpOp.getLoc();
218 
219  // Passed all checks. Start rewriting.
220  OpBuilder::InsertionGuard g(rewriter);
221  rewriter.setInsertionPoint(warpOp);
222 
223  // Step 1: Create scf.if op.
224  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
225  Value isLane0 = rewriter.create<arith::CmpIOp>(
226  loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
227  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
228  /*withElseRegion=*/false);
229  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
230 
231  // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
232  // reads within the scf.if to transit the values captured from above.
233  SmallVector<Value> bbArgReplacements;
234  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
235  Value sequentialVal = warpOpBody->getArgument(it.index());
236  Value distributedVal = it.value();
237  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
238  warpOp.getLaneid(), c0);
239 
240  // Create buffer before the ifOp.
241  rewriter.setInsertionPoint(ifOp);
242  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
243  sequentialVal.getType());
244  // Store distributed vector into buffer, before the ifOp.
245  helper.buildStore(rewriter, loc, distributedVal, buffer);
246  // Load sequential vector from buffer, inside the ifOp.
247  rewriter.setInsertionPointToStart(ifOp.thenBlock());
248  bbArgReplacements.push_back(
249  helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
250  }
251 
252  // Step 3. Insert sync after all the stores and before all the loads.
253  if (!warpOp.getArgs().empty()) {
254  rewriter.setInsertionPoint(ifOp);
255  options.warpSyncronizationFn(loc, rewriter, warpOp);
256  }
257 
258  // Step 4. Move body of warpOp to ifOp.
259  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
260 
261  // Step 5. Insert appropriate writes within scf.if and reads after the
262  // scf.if to transit the values returned by the op.
263  // TODO: at this point, we can reuse the shared memory from previous
264  // buffers.
265  SmallVector<Value> replacements;
266  auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
267  Location yieldLoc = yieldOp.getLoc();
268  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
269  Value sequentialVal = it.value();
270  Value distributedVal = warpOp->getResult(it.index());
271  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
272  warpOp.getLaneid(), c0);
273 
274  // Create buffer before the ifOp.
275  rewriter.setInsertionPoint(ifOp);
276  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
277  sequentialVal.getType());
278 
279  // Store yielded value into buffer, inside the ifOp, before the
280  // terminator.
281  rewriter.setInsertionPoint(yieldOp);
282  helper.buildStore(rewriter, loc, sequentialVal, buffer);
283 
284  // Load distributed value from buffer, after the warpOp.
285  rewriter.setInsertionPointAfter(ifOp);
286  // Result type and yielded value type are the same. This is a broadcast.
287  // E.g.:
288  // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
289  // gpu.yield %cst : f32
290  // }
291  // Both types are f32. The constant %cst is broadcasted to all lanes.
292  // This is described in more detail in the documentation of the op.
293  replacements.push_back(
294  helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
295  }
296 
297  // Step 6. Insert sync after all the stores and before all the loads.
298  if (!yieldOp.getOperands().empty()) {
299  rewriter.setInsertionPointAfter(ifOp);
300  options.warpSyncronizationFn(loc, rewriter, warpOp);
301  }
302 
303  // Step 7. Delete terminator and add empty scf.yield.
304  rewriter.eraseOp(yieldOp);
305  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
306  rewriter.create<scf::YieldOp>(yieldLoc);
307 
308  // Compute replacements for WarpOp results.
309  rewriter.replaceOp(warpOp, replacements);
310 
311  return success();
312  }
313 
private:
  const WarpExecuteOnLane0LoweringOptions &options;
};
317 
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension count equal to
/// the original type's rank and to be a projection where the results are the
/// distributed dimensions. The number of results should be equal to the
/// number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64xf32> distributed with the map
/// (d0, d1, d2) -> (d1) and a warp size of 16 distributes the second
/// dimension (associated with d1) and returns vector<16x2x64xf32>.
326 static VectorType getDistributedType(VectorType originalType, AffineMap map,
327  int64_t warpSize) {
328  SmallVector<int64_t> targetShape(originalType.getShape());
329  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
330  unsigned position = map.getDimPosition(i);
331  if (targetShape[position] % warpSize != 0) {
332  if (warpSize % targetShape[position] != 0) {
333  return VectorType();
334  }
335  warpSize /= targetShape[position];
336  targetShape[position] = 1;
337  continue;
338  }
339  targetShape[position] = targetShape[position] / warpSize;
340  warpSize = 1;
341  break;
342  }
343  if (warpSize != 1) {
344  return VectorType();
345  }
346  VectorType targetType =
347  VectorType::get(targetShape, originalType.getElementType());
348  return targetType;
349 }
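// For example, with a warp size of 32: vector<64xf32> with map (d0) -> (d0)
// distributes to vector<2xf32>; vector<4x32xf32> with map (d0, d1) ->
// (d0, d1) distributes to vector<1x4xf32> (d0 = 4 divides into the warp
// size, leaving 8 lanes for d1: 32 / 8 = 4).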
350 
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract`
/// elements will not be extracted into a separate warp op (the limit should
/// be less than the warp size).
354 ///
355 /// Example:
356 /// ```
/// gpu.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
370 struct WarpOpTransferWrite : public WarpDistributionPattern {
371  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
372  unsigned maxNumElementsToExtract, PatternBenefit b = 1)
373  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
374  maxNumElementsToExtract(maxNumElementsToExtract) {}
375 
376  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
377  /// are multiples of the distribution ratio are supported at the moment.
378  LogicalResult tryDistributeOp(RewriterBase &rewriter,
379  vector::TransferWriteOp writeOp,
380  WarpExecuteOnLane0Op warpOp) const {
381  VectorType writtenVectorType = writeOp.getVectorType();
382 
    // 1. Bail on 0-D writes: they cannot be distributed here and are handled
    // by the extraction path instead.
385  if (writtenVectorType.getRank() == 0)
386  return failure();
387 
388  // 2. Compute the distributed type.
389  AffineMap map = distributionMapFn(writeOp.getVector());
390  VectorType targetType =
391  getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
392  if (!targetType)
393  return failure();
394 
    // 2.5. Compute the distributed type for the new mask.
396  VectorType maskType;
397  if (writeOp.getMask()) {
398  // TODO: Distribution of masked writes with non-trivial permutation maps
399  // requires the distribution of the mask to elementwise match the
400  // distribution of the permuted written vector. Currently the details
401  // of which lane is responsible for which element is captured strictly
402  // by shape information on the warp op, and thus requires materializing
403  // the permutation in IR.
404  if (!writeOp.getPermutationMap().isMinorIdentity())
405  return failure();
406  maskType =
407  getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
408  }
409 
410  // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
411  // the rest.
412  vector::TransferWriteOp newWriteOp =
413  cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
414 
415  // 4. Reindex the write using the distribution map.
416  auto newWarpOp =
417  newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
418 
419  // Delinearize the lane id based on the way threads are divided across the
420  // vector. To get the number of threads per vector dimension, divide the
421  // sequential size by the distributed size along each dim.
422  rewriter.setInsertionPoint(newWriteOp);
423  SmallVector<OpFoldResult> delinearizedIdSizes;
424  for (auto [seqSize, distSize] :
425  llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
426  assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
427  delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
428  }
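    // For example, distributing vector<16x16xf32> to vector<2x2xf32> across
    // 64 lanes yields sizes [8, 8]: the lane id is split into
    // (laneId floordiv 8, laneId mod 8).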
429  SmallVector<Value> delinearized;
430  if (map.getNumResults() > 1) {
431  delinearized = rewriter
432  .create<mlir::affine::AffineDelinearizeIndexOp>(
433  newWarpOp.getLoc(), newWarpOp.getLaneid(),
434  delinearizedIdSizes)
435  .getResults();
436  } else {
437  // If there is only one map result, we can elide the delinearization
438  // op and use the lane id directly.
439  delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
440  }
441 
442  AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
443  Location loc = newWriteOp.getLoc();
444  SmallVector<Value> indices(newWriteOp.getIndices().begin(),
445  newWriteOp.getIndices().end());
446  for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
447  AffineExpr d0, d1;
448  bindDims(newWarpOp.getContext(), d0, d1);
449  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
450  if (!indexExpr)
451  continue;
452  unsigned indexPos = indexExpr.getPosition();
453  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
454  Value laneId = delinearized[vectorPos];
455  auto scale =
456  rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
457  indices[indexPos] = affine::makeComposedAffineApply(
458  rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
459  }
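    // For example, with a distributed dimension of size 2 (`scale` = 2), the
    // write index above becomes `origIndex + 2 * laneId`: lane 3 writes its
    // two-element slice at offset `origIndex + 6`.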
460  newWriteOp.getIndicesMutable().assign(indices);
461 
462  return success();
463  }
464 
  /// Extract TransferWriteOps of single-element vectors into a separate warp
  /// op.
466  LogicalResult tryExtractOp(RewriterBase &rewriter,
467  vector::TransferWriteOp writeOp,
468  WarpExecuteOnLane0Op warpOp) const {
469  Location loc = writeOp.getLoc();
470  VectorType vecType = writeOp.getVectorType();
471 
472  if (vecType.getNumElements() > maxNumElementsToExtract) {
473  return rewriter.notifyMatchFailure(
474  warpOp,
475  llvm::formatv(
476  "writes more elements ({0}) than allowed to extract ({1})",
477  vecType.getNumElements(), maxNumElementsToExtract));
478  }
479 
    // Do not process warp ops that contain only TransferWriteOps (plus the
    // terminator).
481  if (llvm::all_of(warpOp.getOps(),
482  llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
483  return failure();
484 
485  SmallVector<Value> yieldValues = {writeOp.getVector()};
486  SmallVector<Type> retTypes = {vecType};
487  SmallVector<size_t> newRetIndices;
488  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
489  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
490  rewriter.setInsertionPointAfter(newWarpOp);
491 
492  // Create a second warp op that contains only writeOp.
493  auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
494  loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
495  Block &body = secondWarpOp.getBodyRegion().front();
496  rewriter.setInsertionPointToStart(&body);
497  auto newWriteOp =
498  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
499  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
500  rewriter.eraseOp(writeOp);
501  rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
502  return success();
503  }
504 
505  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
506  PatternRewriter &rewriter) const override {
507  auto yield = cast<gpu::YieldOp>(
508  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
509  Operation *lastNode = yield->getPrevNode();
510  auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
511  if (!writeOp)
512  return failure();
513 
514  Value maybeMask = writeOp.getMask();
515  if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
516  return writeOp.getVector() == value ||
517  (maybeMask && maybeMask == value) ||
518  warpOp.isDefinedOutsideOfRegion(value);
519  }))
520  return failure();
521 
522  if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
523  return success();
524 
525  // Masked writes not supported for extraction.
526  if (writeOp.getMask())
527  return failure();
528 
529  if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
530  return success();
531 
532  return failure();
533  }
534 
535 private:
536  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
537  /// execute op with the proper return type. The new write op is updated to
538  /// write the result of the new warp execute op. The old `writeOp` is deleted.
539  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
540  WarpExecuteOnLane0Op warpOp,
541  vector::TransferWriteOp writeOp,
542  VectorType targetType,
543  VectorType maybeMaskType) const {
544  assert(writeOp->getParentOp() == warpOp &&
545  "write must be nested immediately under warp");
546  OpBuilder::InsertionGuard g(rewriter);
547  SmallVector<size_t> newRetIndices;
548  WarpExecuteOnLane0Op newWarpOp;
549  if (maybeMaskType) {
550  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
551  rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
552  TypeRange{targetType, maybeMaskType}, newRetIndices);
553  } else {
554  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
555  rewriter, warpOp, ValueRange{{writeOp.getVector()}},
556  TypeRange{targetType}, newRetIndices);
557  }
558  rewriter.setInsertionPointAfter(newWarpOp);
559  auto newWriteOp =
560  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
561  rewriter.eraseOp(writeOp);
562  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
563  if (maybeMaskType)
564  newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
565  return newWriteOp;
566  }
567 
568  DistributionMapFn distributionMapFn;
569  unsigned maxNumElementsToExtract = 1;
570 };
571 
572 /// Sink out elementwise op feeding into a warp op yield.
573 /// ```
574 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
575 /// ...
576 /// %3 = arith.addf %1, %2 : vector<32xf32>
577 /// gpu.yield %3 : vector<32xf32>
578 /// }
579 /// ```
580 /// To
581 /// ```
582 /// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
583 /// vector<1xf32>, vector<1xf32>) {
584 /// ...
585 /// %4 = arith.addf %2, %3 : vector<32xf32>
586 /// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
587 /// vector<32xf32>
588 /// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
590 struct WarpOpElementwise : public WarpDistributionPattern {
591  using Base::Base;
592  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
593  PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op) &&
             op->getNumResults() == 1;
    });
597  if (!yieldOperand)
598  return failure();
599 
600  Operation *elementWise = yieldOperand->get().getDefiningOp();
601  unsigned operandIndex = yieldOperand->getOperandNumber();
602  Value distributedVal = warpOp.getResult(operandIndex);
603  SmallVector<Value> yieldValues;
604  SmallVector<Type> retTypes;
605  Location loc = warpOp.getLoc();
606  for (OpOperand &operand : elementWise->getOpOperands()) {
607  Type targetType;
608  if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
609  // If the result type is a vector, the operands must also be vectors.
610  auto operandType = cast<VectorType>(operand.get().getType());
611  targetType =
612  VectorType::get(vecType.getShape(), operandType.getElementType());
613  } else {
614  auto operandType = operand.get().getType();
615  assert(!isa<VectorType>(operandType) &&
616  "unexpected yield of vector from op with scalar result type");
617  targetType = operandType;
618  }
619  retTypes.push_back(targetType);
620  yieldValues.push_back(operand.get());
621  }
622  SmallVector<size_t> newRetIndices;
623  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
624  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
625  rewriter.setInsertionPointAfter(newWarpOp);
626  SmallVector<Value> newOperands(elementWise->getOperands().begin(),
627  elementWise->getOperands().end());
628  for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
629  newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
630  }
631  OpBuilder::InsertionGuard g(rewriter);
632  rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
636  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
637  newOp->getResult(0));
638  return success();
639  }
640 };
641 
642 /// Sink out splat constant op feeding into a warp op yield.
643 /// ```
644 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
645 /// ...
646 /// %cst = arith.constant dense<2.0> : vector<32xf32>
647 /// gpu.yield %cst : vector<32xf32>
648 /// }
649 /// ```
650 /// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
656 struct WarpOpConstant : public WarpDistributionPattern {
657  using Base::Base;
658  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
659  PatternRewriter &rewriter) const override {
660  OpOperand *yieldOperand =
661  getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
662  if (!yieldOperand)
663  return failure();
664  auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
665  auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
666  if (!dense)
667  return failure();
668  // Notify the rewriter that the warp op is changing (see the comment on
669  // the WarpOpTransferRead pattern).
670  rewriter.startOpModification(warpOp);
671  unsigned operandIndex = yieldOperand->getOperandNumber();
672  Attribute scalarAttr = dense.getSplatValue<Attribute>();
673  auto newAttr = DenseElementsAttr::get(
674  cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
675  Location loc = warpOp.getLoc();
676  rewriter.setInsertionPointAfter(warpOp);
677  Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
678  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
679  rewriter.finalizeOpModification(warpOp);
680  return success();
681  }
682 };
683 
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
702 struct WarpOpTransferRead : public WarpDistributionPattern {
703  using Base::Base;
704  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
705  PatternRewriter &rewriter) const override {
706  // Try to find a distributable yielded read. Note that this pattern can
707  // still fail at the end after distribution, in which case this might have
708  // missed another distributable read.
709  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
710  // Don't duplicate transfer_read ops when distributing.
711  return isa<vector::TransferReadOp>(op) && op->hasOneUse();
712  });
713  if (!operand)
714  return rewriter.notifyMatchFailure(
715  warpOp, "warp result is not a vector.transfer_read op");
716  auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
717 
718  // Source must be defined outside of the region.
719  if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
720  return rewriter.notifyMatchFailure(
721  read, "source must be defined outside of the region");
722 
723  unsigned operandIndex = operand->getOperandNumber();
724  Value distributedVal = warpOp.getResult(operandIndex);
725 
726  SmallVector<Value, 4> indices(read.getIndices().begin(),
727  read.getIndices().end());
728  auto sequentialType = cast<VectorType>(read.getResult().getType());
729  auto distributedType = cast<VectorType>(distributedVal.getType());
730  AffineMap map = calculateImplicitMap(sequentialType, distributedType);
731  AffineMap indexMap = map.compose(read.getPermutationMap());
732 
733  // Try to delinearize the lane ID to match the rank expected for
734  // distribution.
735  SmallVector<Value> delinearizedIds;
736  if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
737  distributedType.getShape(), warpOp.getWarpSize(),
738  warpOp.getLaneid(), delinearizedIds)) {
739  return rewriter.notifyMatchFailure(
740  read, "cannot delinearize lane ID for distribution");
741  }
742  assert(!delinearizedIds.empty() || map.getNumResults() == 0);
743 
744  // Distribute indices and the mask (if present).
745  OpBuilder::InsertionGuard g(rewriter);
746  SmallVector<Value> additionalResults(indices.begin(), indices.end());
747  SmallVector<Type> additionalResultTypes(indices.size(),
748  rewriter.getIndexType());
749  additionalResults.push_back(read.getPadding());
750  additionalResultTypes.push_back(read.getPadding().getType());
751 
752  bool hasMask = false;
753  if (read.getMask()) {
754  hasMask = true;
755  // TODO: Distribution of masked reads with non-trivial permutation maps
756  // requires the distribution of the mask to elementwise match the
757  // distribution of the permuted written vector. Currently the details
758  // of which lane is responsible for which element is captured strictly
759  // by shape information on the warp op, and thus requires materializing
760  // the permutation in IR.
761  if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
762  return rewriter.notifyMatchFailure(
763  read, "non-trivial permutation maps not supported");
764  VectorType maskType =
765  getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
766  additionalResults.push_back(read.getMask());
767  additionalResultTypes.push_back(maskType);
768  }
769 
770  SmallVector<size_t> newRetIndices;
771  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
772  rewriter, warpOp, additionalResults, additionalResultTypes,
773  newRetIndices);
774  distributedVal = newWarpOp.getResult(operandIndex);
775 
776  // Distributed indices were appended first.
777  SmallVector<Value> newIndices;
778  for (int64_t i = 0, e = indices.size(); i < e; ++i)
779  newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
780 
781  rewriter.setInsertionPointAfter(newWarpOp);
782  for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
783  AffineExpr d0, d1;
784  bindDims(read.getContext(), d0, d1);
785  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
786  if (!indexExpr)
787  continue;
788  unsigned indexPos = indexExpr.getPosition();
789  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
790  int64_t scale = distributedType.getDimSize(vectorPos);
791  newIndices[indexPos] = affine::makeComposedAffineApply(
792  rewriter, read.getLoc(), d0 + scale * d1,
793  {newIndices[indexPos], delinearizedIds[vectorPos]});
794  }
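    // For example, a vector<64xf32> read distributed as vector<2xf32> gives
    // `scale` = 2, so each lane reads at `origIndex + 2 * laneId`.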
795 
796  // Distributed padding value was appended right after the indices.
797  Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
798  // Distributed mask value was added at the end (if the op has a mask).
799  Value newMask =
800  hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
801  : Value();
802  auto newRead = rewriter.create<vector::TransferReadOp>(
803  read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
804  read.getPermutationMapAttr(), newPadding, newMask,
805  read.getInBoundsAttr());
806 
807  rewriter.replaceAllUsesWith(distributedVal, newRead);
808  return success();
809  }
810 };
811 
812 /// Remove any result that has no use along with the matching yieldOp operand.
813 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
814 struct WarpOpDeadResult : public WarpDistributionPattern {
815  using Base::Base;
816  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
817  PatternRewriter &rewriter) const override {
818  SmallVector<Type> newResultTypes;
819  newResultTypes.reserve(warpOp->getNumResults());
820  SmallVector<Value> newYieldValues;
821  newYieldValues.reserve(warpOp->getNumResults());
822  DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
823  DenseMap<OpResult, int64_t> dedupResultPositionMap;
824  auto yield = cast<gpu::YieldOp>(
825  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
826 
827  // Some values may be yielded multiple times and correspond to multiple
828  // results. Deduplicating occurs by taking each result with its matching
829  // yielded value, and:
830  // 1. recording the unique first position at which the value is yielded.
831  // 2. recording for the result, the first position at which the dedup'ed
832  // value is yielded.
833  // 3. skipping from the new result types / new yielded values any result
834  // that has no use or whose yielded value has already been seen.
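    // For example, for `gpu.yield %a, %a, %b` where result #1 is unused, the
    // new warp op yields only `%a, %b`, and results #0 and #1 both map to
    // the deduplicated position of `%a`.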
835  for (OpResult result : warpOp.getResults()) {
836  Value yieldOperand = yield.getOperand(result.getResultNumber());
837  auto it = dedupYieldOperandPositionMap.insert(
838  std::make_pair(yieldOperand, newResultTypes.size()));
839  dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
840  if (result.use_empty() || !it.second)
841  continue;
842  newResultTypes.push_back(result.getType());
843  newYieldValues.push_back(yieldOperand);
844  }
845  // No modification, exit early.
846  if (yield.getNumOperands() == newYieldValues.size())
847  return failure();
848  // Move the body of the old warpOp to a new warpOp.
849  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
850  rewriter, warpOp, newYieldValues, newResultTypes);
851 
852  // Simplify the new warp op after dropping dead results.
853  newWarpOp.getBody()->walk([&](Operation *op) {
854  if (isOpTriviallyDead(op))
855  rewriter.eraseOp(op);
856  });
857 
858  // Replace results of the old warpOp by the new, deduplicated results.
859  SmallVector<Value> newValues;
860  newValues.reserve(warpOp->getNumResults());
861  for (OpResult result : warpOp.getResults()) {
862  if (result.use_empty())
863  newValues.push_back(Value());
864  else
865  newValues.push_back(
866  newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
867  }
868  rewriter.replaceOp(warpOp, newValues);
869  return success();
870  }
871 };
872 
873 // If an operand is directly yielded out of the region we can forward it
874 // directly and it doesn't need to go through the region.
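// For example (illustrative):
//   %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//       args(%v : vector<4xf32>) -> (vector<4xf32>) {
//   ^bb0(%arg0 : vector<4xf32>):
//     gpu.yield %arg0 : vector<4xf32>
//   }
// All uses of %r can be replaced with %v directly.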
875 struct WarpOpForwardOperand : public WarpDistributionPattern {
876  using Base::Base;
877  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
878  PatternRewriter &rewriter) const override {
879  SmallVector<Type> resultTypes;
880  SmallVector<Value> yieldValues;
881  auto yield = cast<gpu::YieldOp>(
882  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
883  Value valForwarded;
884  unsigned resultIndex;
885  for (OpOperand &operand : yield->getOpOperands()) {
886  Value result = warpOp.getResult(operand.getOperandNumber());
887  if (result.use_empty())
888  continue;
889 
890  // Assume all the values coming from above are uniform.
891  if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
892  if (result.getType() != operand.get().getType())
893  continue;
894  valForwarded = operand.get();
895  resultIndex = operand.getOperandNumber();
896  break;
897  }
898  auto arg = dyn_cast<BlockArgument>(operand.get());
899  if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
900  continue;
901  Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
902  if (result.getType() != warpOperand.getType())
903  continue;
904  valForwarded = warpOperand;
905  resultIndex = operand.getOperandNumber();
906  break;
907  }
908  if (!valForwarded)
909  return failure();
910  // Notify the rewriter that the warp op is changing (see the comment on
911  // the WarpOpTransferRead pattern).
912  rewriter.startOpModification(warpOp);
913  rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
914  rewriter.finalizeOpModification(warpOp);
915  return success();
916  }
917 };
918 
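/// Sink out a vector.broadcast feeding into a warp op yield. The broadcast
/// source is yielded by the new warp op (a lane-uniform value) and
/// re-broadcast to the distributed shape afterwards. Illustrative example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   %1 = vector.broadcast %src : f32 to vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes a vector.broadcast of %src to vector<1xf32> after the warp op.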
919 struct WarpOpBroadcast : public WarpDistributionPattern {
920  using Base::Base;
921  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
922  PatternRewriter &rewriter) const override {
923  OpOperand *operand =
924  getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
925  if (!operand)
926  return failure();
927  unsigned int operandNumber = operand->getOperandNumber();
928  auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
929  Location loc = broadcastOp.getLoc();
930  auto destVecType =
931  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
932  Value broadcastSrc = broadcastOp.getSource();
933  Type broadcastSrcType = broadcastSrc.getType();
934 
    // Check that the broadcast actually spans a set of values uniformly
    // across all threads. In other words, check that each thread can
    // reconstruct its own broadcast.
    // For that, we simply check that the broadcast we want to build makes
    // sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
942  SmallVector<size_t> newRetIndices;
943  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
944  rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
945  rewriter.setInsertionPointAfter(newWarpOp);
946  Value broadcasted = rewriter.create<vector::BroadcastOp>(
947  loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
948  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
949  broadcasted);
950  return success();
951  }
952 };
953 
/// Pattern to move a shape cast out of the warp op. A shape cast is
/// essentially a no-op for warp distribution; only the shape needs to be
/// adjusted.
956 struct WarpOpShapeCast : public WarpDistributionPattern {
957  using Base::Base;
958  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
959  PatternRewriter &rewriter) const override {
960  OpOperand *operand =
961  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
962  if (!operand)
963  return failure();
964 
965  auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
966 
967  unsigned int operandNumber = operand->getOperandNumber();
968  auto castDistributedType =
969  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
970  VectorType castOriginalType = oldCastOp.getSourceVectorType();
971  VectorType castResultType = castDistributedType;
972 
973  // We expect the distributed type to have a smaller rank than the original
974  // type. Prepend with size-one dimensions to make them the same.
975  unsigned castDistributedRank = castDistributedType.getRank();
976  unsigned castOriginalRank = castOriginalType.getRank();
977  if (castDistributedRank < castOriginalRank) {
978  SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
979  llvm::append_range(shape, castDistributedType.getShape());
980  castDistributedType =
981  VectorType::get(shape, castDistributedType.getElementType());
982  }
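    // For example, for a vector<4x32xf32> -> vector<128xf32> shape cast
    // distributed across 32 lanes, the distributed result is vector<4xf32>
    // and the source type is rank-extended to vector<1x4xf32>.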
983 
984  SmallVector<size_t> newRetIndices;
985  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
986  rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
987  newRetIndices);
988  rewriter.setInsertionPointAfter(newWarpOp);
989  Value newCast = rewriter.create<vector::ShapeCastOp>(
990  oldCastOp.getLoc(), castResultType,
991  newWarpOp->getResult(newRetIndices[0]));
992  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
993  return success();
994  }
995 };
996 
997 /// Sink out vector.create_mask op feeding into a warp op yield.
998 /// ```
999 /// %0 = ...
1000 /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1001 /// ...
1002 /// %mask = vector.create_mask %0 : vector<32xi1>
1003 /// gpu.yield %mask : vector<32xi1>
1004 /// }
1005 /// ```
1006 /// To
1007 /// ```
1008 /// %0 = ...
1009 /// gpu.warp_execute_on_lane_0(%arg0) {
1010 /// ...
1011 /// }
1012 /// %cmp = arith.cmpi ult, %laneid, %0
1013 /// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
1015 struct WarpOpCreateMask : public WarpDistributionPattern {
1016  using Base::Base;
1017  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1018  PatternRewriter &rewriter) const override {
1019  OpOperand *yieldOperand =
1020  getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1021  if (!yieldOperand)
1022  return failure();
1023 
1024  auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1025 
1026  // Early exit if any values needed for calculating the new mask indices
1027  // are defined inside the warp op.
1028  if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1029  return warpOp.isDefinedOutsideOfRegion(value);
1030  }))
1031  return failure();
1032 
1033  Location loc = mask.getLoc();
1034  unsigned operandIndex = yieldOperand->getOperandNumber();
1035 
1036  auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1037  VectorType seqType = mask.getVectorType();
1038  ArrayRef<int64_t> seqShape = seqType.getShape();
1039  ArrayRef<int64_t> distShape = distType.getShape();
1040 
1041  rewriter.setInsertionPointAfter(warpOp);
1042 
1043  // Delinearize the lane ID for constructing the distributed mask sizes.
1044  SmallVector<Value> delinearizedIds;
1045  if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1046  warpOp.getWarpSize(), warpOp.getLaneid(),
1047  delinearizedIds))
1048  return rewriter.notifyMatchFailure(
1049  mask, "cannot delinearize lane ID for distribution");
1050  assert(!delinearizedIds.empty());
1051 
1052  // Notify the rewriter that the warp op is changing (see the comment on
1053  // the WarpOpTransferRead pattern).
1054  rewriter.startOpModification(warpOp);
1055 
1056  AffineExpr s0, s1;
1057  bindSymbols(rewriter.getContext(), s0, s1);
1058  SmallVector<Value> newOperands;
1059  for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], so the lane-local
      // mask sizes always end up in that range.
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
1068  newOperands.push_back(maskDimIdx);
1069  }
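    // For example, for a sequential vector<32xi1> mask with upper bound %0,
    // distributed as vector<1xi1>: lane 5 computes %0 - 5, which
    // vector.create_mask clamps to [0, 1], so lanes with id below %0 get an
    // active lane-local mask and the remaining lanes get an empty one.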
1070 
1071  auto newMask =
1072  rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
1073  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1074  rewriter.finalizeOpModification(warpOp);
1075  return success();
1076  }
1077 };
1078 
/// Pattern to move out vector.extract of >= 2-D source vectors. 0-D and 1-D
/// sources are handled by the WarpOpExtractScalar pattern. When the yielded
/// vector is not distributed, the extract doesn't need to be distributed
/// either and can simply be propagated outside of the region.
1081 struct WarpOpExtract : public WarpDistributionPattern {
1082  using Base::Base;
1083  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1084  PatternRewriter &rewriter) const override {
1085  OpOperand *operand =
1086  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1087  if (!operand)
1088  return failure();
1089  unsigned int operandNumber = operand->getOperandNumber();
1090  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1091  VectorType extractSrcType = extractOp.getSourceVectorType();
1092  Location loc = extractOp.getLoc();
1093 
1094  // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1095  if (extractSrcType.getRank() <= 1) {
1096  return failure();
1097  }
1098 
1099  // All following cases are 2d or higher dimensional source vectors.
1100 
1101  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1102  // There is no distribution, this is a broadcast. Simply move the extract
1103  // out of the warp op.
1104  // TODO: This could be optimized. E.g., in case of a scalar result, let
1105  // one lane extract and shuffle the result to all other lanes (same as
1106  // the 1d case).
1107  SmallVector<size_t> newRetIndices;
1108  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1109  rewriter, warpOp, {extractOp.getVector()},
1110  {extractOp.getSourceVectorType()}, newRetIndices);
1111  rewriter.setInsertionPointAfter(newWarpOp);
1112  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1113  // Extract from distributed vector.
1114  Value newExtract = rewriter.create<vector::ExtractOp>(
1115  loc, distributedVec, extractOp.getMixedPosition());
1116  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1117  newExtract);
1118  return success();
1119  }
1120 
1121  // Find the distributed dimension. There should be exactly one.
1122  auto distributedType =
1123  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1124  auto yieldedType = cast<VectorType>(operand->get().getType());
1125  int64_t distributedDim = -1;
1126  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1127  if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1128  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1129  // support distributing multiple dimensions in the future.
1130  assert(distributedDim == -1 && "found multiple distributed dims");
1131  distributedDim = i;
1132  }
1133  }
1134  assert(distributedDim != -1 && "could not find distributed dimension");
1135  (void)distributedDim;
1136 
1137  // Yield source vector from warp op.
1138  SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1139  for (int i = 0; i < distributedType.getRank(); ++i)
1140  newDistributedShape[i + extractOp.getNumIndices()] =
1141  distributedType.getDimSize(i);
1142  auto newDistributedType =
1143  VectorType::get(newDistributedShape, distributedType.getElementType());
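    // For example, extracting position [1] from a vector<4x96xf32> source
    // whose result is distributed as vector<3xf32> (96 elements over 32
    // lanes): the source is yielded as vector<4x3xf32> and each lane
    // extracts row 1 of its local slice.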
1144  SmallVector<size_t> newRetIndices;
1145  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1146  rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1147  newRetIndices);
1148  rewriter.setInsertionPointAfter(newWarpOp);
1149  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1150  // Extract from distributed vector.
1151  Value newExtract = rewriter.create<vector::ExtractOp>(
1152  loc, distributedVec, extractOp.getMixedPosition());
1153  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1154  newExtract);
1155  return success();
1156  }
1157 };
1158 
1159 /// Pattern to move out vector.extract with a scalar result.
1160 /// Only supports 1-D and 0-D sources for now.
1161 struct WarpOpExtractScalar : public WarpDistributionPattern {
1162  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1163  PatternBenefit b = 1)
1164  : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1165  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1166  PatternRewriter &rewriter) const override {
1167  OpOperand *operand =
1168  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1169  if (!operand)
1170  return failure();
1171  unsigned int operandNumber = operand->getOperandNumber();
1172  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1173  VectorType extractSrcType = extractOp.getSourceVectorType();
1174  // Only supports 1-D or 0-D sources for now.
1175  if (extractSrcType.getRank() > 1) {
1176  return rewriter.notifyMatchFailure(
1177  extractOp, "only 0-D or 1-D source supported for now");
1178  }
1179  // TODO: Supported shuffle types should be parameterizable, similar to
1180  // `WarpShuffleFromIdxFn`.
1181  if (!extractSrcType.getElementType().isF32() &&
1182  !extractSrcType.getElementType().isInteger(32))
1183  return rewriter.notifyMatchFailure(
1184  extractOp, "only f32/i32 element types are supported");
1185  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1186  Type elType = extractSrcType.getElementType();
1187  VectorType distributedVecType;
1188  if (!is0dOrVec1Extract) {
1189  assert(extractSrcType.getRank() == 1 &&
1190  "expected that extract src rank is 0 or 1");
1191  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1192  return failure();
1193  int64_t elementsPerLane =
1194  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1195  distributedVecType = VectorType::get({elementsPerLane}, elType);
1196  } else {
1197  distributedVecType = extractSrcType;
1198  }
1199  // Yield source vector and position (if present) from warp op.
1200  SmallVector<Value> additionalResults{extractOp.getVector()};
1201  SmallVector<Type> additionalResultTypes{distributedVecType};
1202  additionalResults.append(
1203  SmallVector<Value>(extractOp.getDynamicPosition()));
1204  additionalResultTypes.append(
1205  SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1206 
1207  Location loc = extractOp.getLoc();
1208  SmallVector<size_t> newRetIndices;
1209  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1210  rewriter, warpOp, additionalResults, additionalResultTypes,
1211  newRetIndices);
1212  rewriter.setInsertionPointAfter(newWarpOp);
1213  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1214 
    // 0-d (or single-element) extract: The new warp op broadcasts the source
    // vector to all lanes. All lanes extract the scalar.
1217  if (is0dOrVec1Extract) {
1218  Value newExtract;
1219  SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1220  newExtract =
1221  rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
1222  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1223  newExtract);
1224  return success();
1225  }
1226 
1227  int64_t staticPos = extractOp.getStaticPosition()[0];
1228  OpFoldResult pos = ShapedType::isDynamic(staticPos)
1229  ? (newWarpOp->getResult(newRetIndices[1]))
1230  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1231  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1232  // the value to all other lanes.
1233  int64_t elementsPerLane = distributedVecType.getShape()[0];
1234  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1235  // tid of extracting thread: pos / elementsPerLane
1236  Value broadcastFromTid = affine::makeComposedAffineApply(
1237  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1238  // Extract at position: pos % elementsPerLane
1239  Value newPos =
1240  elementsPerLane == 1
1241  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1242  : affine::makeComposedAffineApply(rewriter, loc,
1243  sym0 % elementsPerLane, pos);
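    // For example, with a warp size of 32 and a vector<64xf32> source
    // (elementsPerLane = 2), extracting position 6 maps to lane 3, which
    // extracts element 6 % 2 = 0 of its local vector<2xf32> and shuffles it
    // to all lanes.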
1244  Value extracted =
1245  rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
1246 
1247  // Shuffle the extracted value to all lanes.
1248  Value shuffled = warpShuffleFromIdxFn(
1249  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1250  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1251  return success();
1252  }
1253 
1254 private:
1255  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1256 };
1257 
1258 /// Pattern to convert vector.extractelement to vector.extract.
1259 struct WarpOpExtractElement : public WarpDistributionPattern {
1260  using Base::Base;
1261  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1262  PatternRewriter &rewriter) const override {
1263  OpOperand *operand =
1264  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1265  if (!operand)
1266  return failure();
1267  auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1268  SmallVector<OpFoldResult> indices;
1269  if (auto pos = extractOp.getPosition()) {
1270  indices.push_back(pos);
1271  }
1272  rewriter.setInsertionPoint(extractOp);
1273  rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1274  extractOp, extractOp.getVector(), indices);
1275  return success();
1276  }
1277 };
1278 
1279 /// Pattern to move out vector.insert with a scalar input.
1280 /// Only supports 1-D and 0-D destinations for now.
1281 struct WarpOpInsertScalar : public WarpDistributionPattern {
1282  using Base::Base;
1283  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1284  PatternRewriter &rewriter) const override {
1285  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1286  if (!operand)
1287  return failure();
1288  unsigned int operandNumber = operand->getOperandNumber();
1289  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1290  VectorType vecType = insertOp.getDestVectorType();
1291  VectorType distrType =
1292  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1293 
1294  // Only supports 1-D or 0-D destinations for now.
    if (vecType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          insertOp, "only 0-D or 1-D destination supported for now");
    }
1299 
1300  // Yield destination vector, source scalar and position from warp op.
1301  SmallVector<Value> additionalResults{insertOp.getDest(),
1302  insertOp.getSource()};
1303  SmallVector<Type> additionalResultTypes{distrType,
1304  insertOp.getSource().getType()};
1305  additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1306  additionalResultTypes.append(
1307  SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1308 
1309  Location loc = insertOp.getLoc();
1310  SmallVector<size_t> newRetIndices;
1311  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1312  rewriter, warpOp, additionalResults, additionalResultTypes,
1313  newRetIndices);
1314  rewriter.setInsertionPointAfter(newWarpOp);
1315  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1316  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1317  rewriter.setInsertionPointAfter(newWarpOp);
1318 
1319  OpFoldResult pos;
1320  if (vecType.getRank() != 0) {
1321  int64_t staticPos = insertOp.getStaticPosition()[0];
1322  pos = ShapedType::isDynamic(staticPos)
1323  ? (newWarpOp->getResult(newRetIndices[2]))
1324  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1325  }
1326 
1327  // This condition is always true for 0-d vectors.
1328  if (vecType == distrType) {
1329  Value newInsert;
1330  SmallVector<OpFoldResult> indices;
1331  if (pos) {
1332  indices.push_back(pos);
1333  }
1334  newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1335  distributedVec, indices);
1336  // Broadcast: Simply move the vector.insert op out.
1337  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1338  newInsert);
1339  return success();
1340  }
1341 
1342  // This is a distribution. Only one lane should insert.
1343  int64_t elementsPerLane = distrType.getShape()[0];
1344  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
 1345  // tid of inserting lane: pos / elementsPerLane
1346  Value insertingLane = affine::makeComposedAffineApply(
1347  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1348  // Insert position: pos % elementsPerLane
 1349  OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
 1350  rewriter, loc, sym0 % elementsPerLane, pos);
1351  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1352  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1353  Value newResult =
1354  rewriter
1355  .create<scf::IfOp>(
1356  loc, isInsertingLane,
1357  /*thenBuilder=*/
1358  [&](OpBuilder &builder, Location loc) {
1359  Value newInsert = builder.create<vector::InsertOp>(
1360  loc, newSource, distributedVec, newPos);
1361  builder.create<scf::YieldOp>(loc, newInsert);
1362  },
1363  /*elseBuilder=*/
1364  [&](OpBuilder &builder, Location loc) {
1365  builder.create<scf::YieldOp>(loc, distributedVec);
1366  })
1367  .getResult(0);
1368  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1369  return success();
1370  }
1371 };
1372 
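/// Pattern to move out vector.insert with a vector source. Only handles 2-D
/// and higher dimensional destinations; 0-D and 1-D destinations are left to
/// WarpOpInsertScalar. If the distributed dimension lands inside the source
/// vector, every lane inserts its own slice; if the insertion indices consume
/// the distributed dimension, a single lane performs the whole insert inside
/// an scf.if (see the in-line example at the source-type computation below).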
1373 struct WarpOpInsert : public WarpDistributionPattern {
1374  using Base::Base;
1375  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1376  PatternRewriter &rewriter) const override {
1377  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1378  if (!operand)
1379  return failure();
1380  unsigned int operandNumber = operand->getOperandNumber();
1381  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1382  Location loc = insertOp.getLoc();
1383 
1384  // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1385  if (insertOp.getDestVectorType().getRank() <= 1) {
1386  return failure();
1387  }
1388 
 1389  // All following cases have 2-D or higher dimensional destination vectors.
1390 
1391  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1392  // There is no distribution, this is a broadcast. Simply move the insert
1393  // out of the warp op.
1394  SmallVector<size_t> newRetIndices;
1395  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1396  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1397  {insertOp.getSourceType(), insertOp.getDestVectorType()},
1398  newRetIndices);
1399  rewriter.setInsertionPointAfter(newWarpOp);
1400  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1401  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1402  Value newResult = rewriter.create<vector::InsertOp>(
1403  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1404  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1405  newResult);
1406  return success();
1407  }
1408 
1409  // Find the distributed dimension. There should be exactly one.
1410  auto distrDestType =
1411  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1412  auto yieldedType = cast<VectorType>(operand->get().getType());
1413  int64_t distrDestDim = -1;
1414  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1415  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1416  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1417  // support distributing multiple dimensions in the future.
1418  assert(distrDestDim == -1 && "found multiple distributed dims");
1419  distrDestDim = i;
1420  }
1421  }
1422  assert(distrDestDim != -1 && "could not find distributed dimension");
1423 
1424  // Compute the distributed source vector type.
1425  VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1426  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1427  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1428  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1429  // insert a smaller vector<3xf32>.
1430  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1431  // case, one lane will insert the source vector<96xf32>. The other
1432  // lanes will not do anything.
1433  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1434  if (distrSrcDim >= 0)
1435  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1436  auto distrSrcType =
1437  VectorType::get(distrSrcShape, distrDestType.getElementType());
1438 
1439  // Yield source and dest vectors from warp op.
1440  SmallVector<size_t> newRetIndices;
1441  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1442  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1443  {distrSrcType, distrDestType}, newRetIndices);
1444  rewriter.setInsertionPointAfter(newWarpOp);
1445  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1446  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1447 
1448  // Insert into the distributed vector.
1449  Value newResult;
1450  if (distrSrcDim >= 0) {
1451  // Every lane inserts a small piece.
1452  newResult = rewriter.create<vector::InsertOp>(
1453  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1454  } else {
1455  // One lane inserts the entire source vector.
1456  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1457  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1458  SmallVector<int64_t> newPos = getAsIntegers(pos);
1459  // tid of inserting lane: pos / elementsPerLane
1460  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1461  loc, newPos[distrDestDim] / elementsPerLane);
1462  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1463  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1464  // Insert position: pos % elementsPerLane
1465  newPos[distrDestDim] %= elementsPerLane;
1466  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1467  Value newInsert = builder.create<vector::InsertOp>(
1468  loc, distributedSrc, distributedDest, newPos);
1469  builder.create<scf::YieldOp>(loc, newInsert);
1470  };
1471  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1472  builder.create<scf::YieldOp>(loc, distributedDest);
1473  };
1474  newResult = rewriter
1475  .create<scf::IfOp>(loc, isInsertingLane,
1476  /*thenBuilder=*/insertingBuilder,
1477  /*elseBuilder=*/nonInsertingBuilder)
1478  .getResult(0);
1479  }
1480 
1481  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1482  return success();
1483  }
1484 };
1485 
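/// Pattern to convert vector.insertelement to vector.insert, mirroring
/// WarpOpExtractElement above. An illustrative example:
/// ```
/// %r = vector.insertelement %f, %v[%idx : index] : vector<16xf32>
/// ```
/// becomes:
/// ```
/// %r = vector.insert %f, %v[%idx] : f32 into vector<16xf32>
/// ```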
1486 struct WarpOpInsertElement : public WarpDistributionPattern {
1487  using Base::Base;
1488  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1489  PatternRewriter &rewriter) const override {
1490  OpOperand *operand =
1491  getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1492  if (!operand)
1493  return failure();
1494  auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1495  SmallVector<OpFoldResult> indices;
1496  if (auto pos = insertOp.getPosition()) {
1497  indices.push_back(pos);
1498  }
1499  rewriter.setInsertionPoint(insertOp);
1500  rewriter.replaceOpWithNewOp<vector::InsertOp>(
1501  insertOp, insertOp.getSource(), insertOp.getDest(), indices);
1502  return success();
1503  }
1504 };
1505 
1506 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1507 /// the scf.ForOp is the last operation in the region so that it doesn't
1508 /// change the order of execution. This creates a new scf.for region after the
1509 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1510 /// WarpExecuteOnLane0Op region. Example:
1511 /// ```
1512 /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1513 /// ...
1514 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1515 /// -> (vector<128xf32>) {
1516 /// ...
1517 /// scf.yield %r : vector<128xf32>
1518 /// }
1519 /// gpu.yield %v1 : vector<128xf32>
1520 /// }
1521 /// ```
 1522 /// To:
 /// ```
1523 /// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1524 /// ...
1525 /// gpu.yield %v : vector<128xf32>
1526 /// }
 1527 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1528 /// -> (vector<4xf32>) {
1529 /// %iw = gpu.warp_execute_on_lane_0(%laneid)
1530 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1531 /// ^bb0(%arg: vector<128xf32>):
1532 /// ...
1533 /// gpu.yield %ir : vector<128xf32>
1534 /// }
1535 /// scf.yield %iw : vector<4xf32>
1536 /// }
1537 /// ```
1538 struct WarpOpScfForOp : public WarpDistributionPattern {
1539 
1540  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1541  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1542  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1543  PatternRewriter &rewriter) const override {
1544  auto yield = cast<gpu::YieldOp>(
1545  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1546  // Only pick up forOp if it is the last op in the region.
1547  Operation *lastNode = yield->getPrevNode();
1548  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1549  if (!forOp)
1550  return failure();
1551  // Collect Values that come from the warp op but are outside the forOp.
 1552  // Those Values need to be returned by the original warpOp and passed to
1553  // the new op.
1554  llvm::SmallSetVector<Value, 32> escapingValues;
1555  SmallVector<Type> inputTypes;
1556  SmallVector<Type> distTypes;
 1557  visitUsedValuesDefinedAbove(
 1558  forOp.getBodyRegion(), [&](OpOperand *operand) {
1559  Operation *parent = operand->get().getParentRegion()->getParentOp();
1560  if (warpOp->isAncestor(parent)) {
1561  if (!escapingValues.insert(operand->get()))
1562  return;
1563  Type distType = operand->get().getType();
1564  if (auto vecType = dyn_cast<VectorType>(distType)) {
1565  AffineMap map = distributionMapFn(operand->get());
1566  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1567  }
1568  inputTypes.push_back(operand->get().getType());
1569  distTypes.push_back(distType);
1570  }
1571  });
1572 
1573  if (llvm::is_contained(distTypes, Type{}))
1574  return failure();
1575 
1576  SmallVector<size_t> newRetIndices;
1577  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1578  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1579  newRetIndices);
1580  yield = cast<gpu::YieldOp>(
1581  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1582 
1583  SmallVector<Value> newOperands;
1584  SmallVector<unsigned> resultIdx;
1585  // Collect all the outputs coming from the forOp.
1586  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1587  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1588  continue;
1589  auto forResult = cast<OpResult>(yieldOperand.get());
1590  newOperands.push_back(
1591  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1592  yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1593  resultIdx.push_back(yieldOperand.getOperandNumber());
1594  }
1595 
1596  OpBuilder::InsertionGuard g(rewriter);
1597  rewriter.setInsertionPointAfter(newWarpOp);
1598 
1599  // Create a new for op outside the region with a WarpExecuteOnLane0Op
1600  // region inside.
1601  auto newForOp = rewriter.create<scf::ForOp>(
1602  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1603  forOp.getStep(), newOperands);
1604  rewriter.setInsertionPointToStart(newForOp.getBody());
1605 
1606  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1607  newForOp.getRegionIterArgs().end());
1608  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1609  forOp.getResultTypes().end());
1610  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1611  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1612  warpInput.push_back(newWarpOp.getResult(retIdx));
1613  argIndexMapping[escapingValues[i]] = warpInputType.size();
1614  warpInputType.push_back(inputTypes[i]);
1615  }
1616  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1617  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1618  newWarpOp.getWarpSize(), warpInput, warpInputType);
1619 
1620  SmallVector<Value> argMapping;
1621  argMapping.push_back(newForOp.getInductionVar());
1622  for (Value args : innerWarp.getBody()->getArguments()) {
1623  argMapping.push_back(args);
1624  }
1625  argMapping.resize(forOp.getBody()->getNumArguments());
1626  SmallVector<Value> yieldOperands;
1627  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1628  yieldOperands.push_back(operand);
1629  rewriter.eraseOp(forOp.getBody()->getTerminator());
1630  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1631  rewriter.setInsertionPointToEnd(innerWarp.getBody());
1632  rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
1633  rewriter.setInsertionPointAfter(innerWarp);
1634  if (!innerWarp.getResults().empty())
1635  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1636  rewriter.eraseOp(forOp);
1637  // Replace the warpOp result coming from the original ForOp.
1638  for (const auto &res : llvm::enumerate(resultIdx)) {
1639  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1640  newForOp.getResult(res.index()));
1641  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1642  }
1643  newForOp.walk([&](Operation *op) {
1644  for (OpOperand &operand : op->getOpOperands()) {
1645  auto it = argIndexMapping.find(operand.get());
1646  if (it == argIndexMapping.end())
1647  continue;
1648  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1649  }
1650  });
1651 
1652  // Finally, hoist out any now uniform code from the inner warp op.
1653  mlir::vector::moveScalarUniformCode(innerWarp);
1654  return success();
1655  }
1656 
1657 private:
1658  DistributionMapFn distributionMapFn;
1659 };
1660 
1661 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
 1662 /// The vector is reduced in parallel. Currently limited to vector sizes
 1663 /// that are a multiple of the warp size. E.g.:
1664 /// ```
1665 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1666 /// %0 = "some_def"() : () -> (vector<32xf32>)
1667 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1668 /// gpu.yield %1 : f32
1669 /// }
1670 /// ```
1671 /// is lowered to:
1672 /// ```
1673 /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1674 /// %1 = "some_def"() : () -> (vector<32xf32>)
1675 /// gpu.yield %1 : vector<32xf32>
1676 /// }
1677 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1678 /// %r = ("warp.reduction %a")
1679 /// ```
1680 struct WarpOpReduction : public WarpDistributionPattern {
1681  WarpOpReduction(MLIRContext *context,
1682  DistributedReductionFn distributedReductionFn,
1683  PatternBenefit benefit = 1)
1684  : WarpDistributionPattern(context, benefit),
1685  distributedReductionFn(std::move(distributedReductionFn)) {}
1686 
1687  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1688  PatternRewriter &rewriter) const override {
1689  OpOperand *yieldOperand =
1690  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1691  if (!yieldOperand)
1692  return failure();
1693 
1694  auto reductionOp =
1695  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1696  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1697  // Only rank 1 vectors supported.
1698  if (vectorType.getRank() != 1)
1699  return rewriter.notifyMatchFailure(
1700  warpOp, "Only rank 1 reductions can be distributed.");
 1701  // Only vector sizes that are a multiple of the warp size are supported.
 1702  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
 1703  return rewriter.notifyMatchFailure(
 1704  warpOp, "Reduction vector dimension must be a multiple of warp size.");
1705  if (!reductionOp.getType().isIntOrFloat())
1706  return rewriter.notifyMatchFailure(
1707  warpOp, "Reduction distribution currently only supports floats and "
1708  "integer types.");
1709 
1710  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1711  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1712  unsigned operandIndex = yieldOperand->getOperandNumber();
1713  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1714  SmallVector<Type> retTypes = {
1715  VectorType::get({numElements}, reductionOp.getType())};
1716  if (reductionOp.getAcc()) {
1717  yieldValues.push_back(reductionOp.getAcc());
1718  retTypes.push_back(reductionOp.getAcc().getType());
1719  }
1720  SmallVector<size_t> newRetIndices;
1721  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1722  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1723  rewriter.setInsertionPointAfter(newWarpOp);
1724 
1725  // Obtain data to reduce for a single lane.
1726  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1727  // Distribute and reduce across threads.
1728  Value fullReduce =
1729  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1730  reductionOp.getKind(), newWarpOp.getWarpSize());
1731  if (reductionOp.getAcc()) {
1732  fullReduce = vector::makeArithReduction(
1733  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1734  newWarpOp.getResult(newRetIndices[1]));
1735  }
1736  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1737  return success();
1738  }
1739 
1740 private:
1741  DistributedReductionFn distributedReductionFn;
1742 };
1743 
1744 } // namespace
1745 
 1746 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
 1747  RewritePatternSet &patterns,
 1748  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
 1749  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1750 }
1751 
1752 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1753  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1754  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
1755  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1756  maxNumElementsToExtract, benefit);
1757 }
1758 
1759 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1760  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1761  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1762  PatternBenefit readBenefit) {
1763  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1764  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1765  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1766  WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1767  WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1768  patterns.getContext(), benefit);
1769  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
1770  benefit);
1771  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1772  benefit);
1773 }
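
// An illustrative sketch (not part of this file's API surface) of how a
// client could wire up these entry points. `distributionFn` and `shuffleFn`
// are hypothetical callbacks, and the greedy driver is just one possible way
// to apply the patterns:
// ```
// RewritePatternSet patterns(ctx);
// DistributionMapFn distributionFn = [](Value val) {
//   // Assumption: distribute the innermost vector dimension (rank >= 1).
//   auto vecType = cast<VectorType>(val.getType());
//   int64_t rank = vecType.getRank();
//   MLIRContext *ctx = val.getContext();
//   return AffineMap::get(rank, 0, {getAffineDimExpr(rank - 1, ctx)}, ctx);
// };
// WarpShuffleFromIdxFn shuffleFn = [](Location loc, OpBuilder &builder,
//                                     Value val, Value srcIdx,
//                                     int64_t warpSz) {
//   // Forward `val` from lane `srcIdx` with an idx-mode gpu.shuffle.
//   Type i32Ty = builder.getIntegerType(32);
//   Value srcIdxI32 = builder.create<arith::IndexCastOp>(loc, i32Ty, srcIdx);
//   Value warpSzI32 = builder.create<arith::ConstantOp>(
//       loc, builder.getIntegerAttr(i32Ty, warpSz));
//   return builder
//       .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
//                               gpu::ShuffleMode::IDX)
//       .getResult(0);
// };
// vector::populatePropagateWarpVectorDistributionPatterns(
//     patterns, distributionFn, shuffleFn);
// (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
// ```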
1774 
1775 void mlir::vector::populateDistributeReduction(
 1776  RewritePatternSet &patterns,
 1777  const DistributedReductionFn &distributedReductionFn,
1778  PatternBenefit benefit) {
1779  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1780  benefit);
1781 }
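
// One possible `distributedReductionFn`, sketched after the butterfly-shuffle
// reduction used in MLIR's tests; it assumes 32-bit element types and a
// power-of-two warp size:
// ```
// static Value warpReduction(Location loc, OpBuilder &builder, Value input,
//                            vector::CombiningKind kind, uint32_t size) {
//   // First reduce the per-lane vector to a scalar on each lane.
//   Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
//   // Then combine the lane scalars with butterfly (xor) shuffles.
//   for (uint64_t i = 1; i < size; i <<= 1) {
//     Value shuffled = builder
//                          .create<gpu::ShuffleOp>(loc, laneVal, i,
//                                                  /*width=*/size,
//                                                  gpu::ShuffleMode::XOR)
//                          .getShuffleResult();
//     laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
//   }
//   return laneVal;
// }
// ```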
1782 
1783 /// Helper to know if an op can be hoisted out of the region.
1784 static bool canBeHoisted(Operation *op,
1785  function_ref<bool(Value)> definedOutside) {
1786  return llvm::all_of(op->getOperands(), definedOutside) &&
1787  isMemoryEffectFree(op) && op->getNumRegions() == 0;
1788 }
1789 
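/// Hoist ops out of a WarpExecuteOnLane0Op that are uniform across lanes:
/// no vector results, no memory effects, no regions, and all operands defined
/// outside the region (or themselves hoisted). An illustrative example:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %i = arith.addi %a, %b : index   // uniform scalar, moved before %w
///   %v = "some_def"(%i) : (index) -> (vector<32xf32>)
///   gpu.yield %v : vector<32xf32>
/// }
/// ```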
1790 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1791  Block *body = warpOp.getBody();
1792 
1793  // Keep track of the ops we want to hoist.
1794  llvm::SmallSetVector<Operation *, 8> opsToMove;
1795 
1796  // Helper to check if a value is or will be defined outside of the region.
1797  auto isDefinedOutsideOfBody = [&](Value value) {
1798  auto *definingOp = value.getDefiningOp();
1799  return (definingOp && opsToMove.count(definingOp)) ||
1800  warpOp.isDefinedOutsideOfRegion(value);
1801  };
1802 
1803  // Do not use walk here, as we do not want to go into nested regions and hoist
1804  // operations from there.
1805  for (auto &op : body->without_terminator()) {
1806  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1807  return isa<VectorType>(result.getType());
1808  });
1809  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1810  opsToMove.insert(&op);
1811  }
1812 
1813  // Move all the ops marked as uniform outside of the region.
1814  for (Operation *op : opsToMove)
1815  op->moveBefore(warpOp);
1816 }