1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallVectorExtras.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <utility>
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 using namespace mlir::gpu;
30 
31 /// Currently the distribution map is implicit based on the vector shape. In the
32 /// future it will be part of the op.
33 /// Example:
34 /// ```
35 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
36 /// ...
37 /// gpu.yield %3 : vector<32x16x64xf32>
38 /// }
39 /// ```
40 /// Would have an implicit map of:
41 /// `(d0, d1, d2) -> (d0, d2)`
42 static AffineMap calculateImplicitMap(VectorType sequentialType,
43  VectorType distributedType) {
44  SmallVector<AffineExpr> perm;
45  perm.reserve(1);
46  // Check which dimensions of the sequential type differ from the dimensions
47  // of the distributed type to find the distributed dimensions. Then associate
48  // each distributed dimension with an ID in order.
49  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50  if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
51  perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
52  }
53  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
54  distributedType.getContext());
55  return map;
56 }
57 
58 /// Given a sequential and distributed vector type, returns the distributed
59 /// dimension. This function expects that only a single dimension is
60 /// distributed.
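/// Example (illustrative only): for a sequential type vector<32x16xf32> that
/// is distributed as vector<1x16xf32>, only dimension 0 differs, so the
/// distributed dimension is 0:
/// ```
/// getDistributedDim(vector<32x16xf32>, vector<1x16xf32>) == 0
/// ```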
61 static int getDistributedDim(VectorType sequentialType,
62  VectorType distributedType) {
63  assert(sequentialType.getRank() == distributedType.getRank() &&
64  "sequential and distributed vector types must have the same rank");
65  int64_t distributedDim = -1;
66  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
67  if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
68  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
69  // support distributing multiple dimensions in the future.
70  assert(distributedDim == -1 && "found multiple distributed dims");
71  distributedDim = i;
72  }
73  }
74  return distributedDim;
75 }
76 
77 namespace {
78 
79 /// Helper struct to create the load / store operations that permit transit
80 /// through the parallel / sequential and the sequential / parallel boundaries
81 /// when performing `rewriteWarpOpToScfFor`.
82 ///
83 /// The vector distribution dimension is inferred from the vector types.
84 struct DistributedLoadStoreHelper {
85  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
86  Value laneId, Value zero)
87  : sequentialVal(sequentialVal), distributedVal(distributedVal),
88  laneId(laneId), zero(zero) {
89  sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
90  distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
91  if (sequentialVectorType && distributedVectorType)
92  distributionMap =
93  calculateImplicitMap(sequentialVectorType, distributedVectorType);
94  }
95 
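 /// Build the offset of this lane's slice along dimension `index` of the
 /// sequential value, i.e. `laneId * distributedVectorType.getDimSize(index)`
 /// materialized as an affine.apply on the lane id.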
96  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
97  int64_t distributedSize = distributedVectorType.getDimSize(index);
98  AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
99  return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
100  ArrayRef<Value>{laneId});
101  }
102 
103  /// Create a store during the process of distributing the
104  /// `gpu.warp_execute_on_lane_0` op.
105  /// Vector distribution assumes the following convention regarding the
106  /// temporary buffers that are created to transition values. This **must**
107  /// be properly specified in the `options.warpAllocationFn`:
108  /// 1. scalars of type T transit through a memref<1xT>.
109  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
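 /// For example, under this convention (illustrative types; the actual buffer
 /// is produced by `options.warpAllocationFn`):
 /// ```
 /// f32            -> memref<1xf32>
 /// vector<32xf32> -> memref<32xf32>
 /// ```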
110  Operation *buildStore(RewriterBase &b, Location loc, Value val,
111  Value buffer) {
112  assert((val == distributedVal || val == sequentialVal) &&
113  "Must store either the preregistered distributed or the "
114  "preregistered sequential value.");
115  // Scalar case can directly use memref.store.
116  if (!isa<VectorType>(val.getType()))
117  return memref::StoreOp::create(b, loc, val, buffer, zero);
118 
119  // Vector case must use vector::TransferWriteOp which will later lower to
120  // vector.store or memref.store depending on further lowerings.
121  int64_t rank = sequentialVectorType.getRank();
122  SmallVector<Value> indices(rank, zero);
123  if (val == distributedVal) {
124  for (auto dimExpr : distributionMap.getResults()) {
125  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126  indices[index] = buildDistributedOffset(b, loc, index);
127  }
128  }
129  SmallVector<bool> inBounds(indices.size(), true);
130  return vector::TransferWriteOp::create(
131  b, loc, val, buffer, indices,
132  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
133  }
134 
135  /// Create a load during the process of distributing the
136  /// `gpu.warp_execute_on_lane_0` op.
137  /// Vector distribution assumes the following convention regarding the
138  /// temporary buffers that are created to transition values. This **must**
139  /// be properly specified in the `options.warpAllocationFn`:
140  /// 1. scalars of type T transit through a memref<1xT>.
141  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
142  ///
143  /// When the requested type is not distributed (e.g. a scalar), the load is not
144  /// offset per lane, to account for the broadcast semantics of the `gpu.warp_execute_on_lane_0` op.
145  ///
146  /// Example:
147  ///
148  /// ```
149  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
150  /// gpu.yield %cst : f32
151  /// }
152  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
153  /// ```
154  /// This behavior is described in more detail in the documentation of the op.
155  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
156 
157  // Scalar case can directly use memref.load.
158  if (!isa<VectorType>(type))
159  return memref::LoadOp::create(b, loc, buffer, zero);
160 
161  // Other cases must be vectors at the moment.
162  // Vector case must use vector::TransferReadOp which will later lower to
163  // vector.load or memref.load depending on further lowerings.
164  assert((type == distributedVectorType || type == sequentialVectorType) &&
165  "Must load either the preregistered distributed or the "
166  "preregistered sequential type.");
167  SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
168  if (type == distributedVectorType) {
169  for (auto dimExpr : distributionMap.getResults()) {
170  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171  indices[index] = buildDistributedOffset(b, loc, index);
172  }
173  }
174  SmallVector<bool> inBounds(indices.size(), true);
175  return vector::TransferReadOp::create(
176  b, loc, cast<VectorType>(type), buffer, indices,
177  /*padding=*/std::nullopt,
178  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
179  }
180 
181  Value sequentialVal, distributedVal, laneId, zero;
182  VectorType sequentialVectorType, distributedVectorType;
183  AffineMap distributionMap;
184 };
185 
186 } // namespace
187 
188 // Clones `op` into a new operation that takes `operands` and returns
189 // `resultTypes`.
190 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
191  Location loc, Operation *op,
192  ArrayRef<Value> operands,
193  ArrayRef<Type> resultTypes) {
194  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
195  op->getAttrs());
196  return rewriter.create(res);
197 }
198 
199 namespace {
200 
201 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
202 /// thread `laneId` executes the entirety of the computation.
203 ///
204 /// After the transformation:
205 /// - the IR within the scf.if op can be thought of as executing sequentially
206 /// (from the point of view of threads along `laneId`).
207 /// - the IR outside of the scf.if op can be thought of as executing in
208 /// parallel (from the point of view of threads along `laneId`).
209 ///
210 /// Values that need to transit through the parallel / sequential and the
211 /// sequential / parallel boundaries do so via reads and writes to a temporary
212 /// memory location.
213 ///
214 /// The transformation proceeds in multiple steps:
215 /// 1. Create the scf.if op.
216 /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
217 /// within the scf.if to transit the values captured from above.
218 /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
219 /// consistent within the scf.if.
220 /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
221 /// 5. Insert appropriate writes within scf.if and reads after the scf.if to
222 /// transit the values returned by the op.
223 /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
224 /// consistent after the scf.if.
225 /// 7. Perform late cleanups.
226 ///
227 /// All this assumes the vector distribution occurs along the most minor
228 /// distributed vector dimension.
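/// A rough sketch of the rewrite for a single yielded value (only illustrative;
/// `%buffer` stands for the transit buffer from `options.warpAllocationFn`,
/// `%pad` for the read padding value, and synchronization is provided by
/// `options.warpSyncronizationFn`):
/// ```
/// // Before:
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// // After (conceptually):
/// %c0 = arith.constant 0 : index
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   ...
///   vector.transfer_write %v, %buffer[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// // ... synchronize ...
/// %r = vector.transfer_read %buffer[%laneid], %pad : memref<32xf32>, vector<1xf32>
/// ```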
229 struct WarpOpToScfIfPattern : public WarpDistributionPattern {
230  WarpOpToScfIfPattern(MLIRContext *context,
231  const WarpExecuteOnLane0LoweringOptions &options,
232  PatternBenefit benefit = 1)
233  : WarpDistributionPattern(context, benefit), options(options) {}
234 
235  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236  PatternRewriter &rewriter) const override {
237  assert(warpOp.getBodyRegion().hasOneBlock() &&
238  "expected WarpOp with single block");
239  Block *warpOpBody = &warpOp.getBodyRegion().front();
240  Location loc = warpOp.getLoc();
241 
242  // Passed all checks. Start rewriting.
243  OpBuilder::InsertionGuard g(rewriter);
244  rewriter.setInsertionPoint(warpOp);
245 
246  // Step 1: Create scf.if op.
247  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
248  Value isLane0 = arith::CmpIOp::create(
249  rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
250  auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
251  /*withElseRegion=*/false);
252  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
253 
254  // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
255  // reads within the scf.if to transit the values captured from above.
256  SmallVector<Value> bbArgReplacements;
257  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
258  Value sequentialVal = warpOpBody->getArgument(it.index());
259  Value distributedVal = it.value();
260  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
261  warpOp.getLaneid(), c0);
262 
263  // Create buffer before the ifOp.
264  rewriter.setInsertionPoint(ifOp);
265  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
266  sequentialVal.getType());
267  // Store distributed vector into buffer, before the ifOp.
268  helper.buildStore(rewriter, loc, distributedVal, buffer);
269  // Load sequential vector from buffer, inside the ifOp.
270  rewriter.setInsertionPointToStart(ifOp.thenBlock());
271  bbArgReplacements.push_back(
272  helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
273  }
274 
275  // Step 3. Insert sync after all the stores and before all the loads.
276  if (!warpOp.getArgs().empty()) {
277  rewriter.setInsertionPoint(ifOp);
278  options.warpSyncronizationFn(loc, rewriter, warpOp);
279  }
280 
281  // Step 4. Move body of warpOp to ifOp.
282  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
283 
284  // Step 5. Insert appropriate writes within scf.if and reads after the
285  // scf.if to transit the values returned by the op.
286  // TODO: at this point, we can reuse the shared memory from previous
287  // buffers.
288  SmallVector<Value> replacements;
289  auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
290  Location yieldLoc = yieldOp.getLoc();
291  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
292  Value sequentialVal = it.value();
293  Value distributedVal = warpOp->getResult(it.index());
294  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
295  warpOp.getLaneid(), c0);
296 
297  // Create buffer before the ifOp.
298  rewriter.setInsertionPoint(ifOp);
299  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
300  sequentialVal.getType());
301 
302  // Store yielded value into buffer, inside the ifOp, before the
303  // terminator.
304  rewriter.setInsertionPoint(yieldOp);
305  helper.buildStore(rewriter, loc, sequentialVal, buffer);
306 
307  // Load distributed value from buffer, after the warpOp.
308  rewriter.setInsertionPointAfter(ifOp);
309  // Result type and yielded value type are the same. This is a broadcast.
310  // E.g.:
311  // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
312  // gpu.yield %cst : f32
313  // }
314  // Both types are f32. The constant %cst is broadcasted to all lanes.
315  // This is described in more detail in the documentation of the op.
316  replacements.push_back(
317  helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
318  }
319 
320  // Step 6. Insert sync after all the stores and before all the loads.
321  if (!yieldOp.getOperands().empty()) {
322  rewriter.setInsertionPointAfter(ifOp);
323  options.warpSyncronizationFn(loc, rewriter, warpOp);
324  }
325 
326  // Step 7. Delete terminator and add empty scf.yield.
327  rewriter.eraseOp(yieldOp);
328  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
329  scf::YieldOp::create(rewriter, yieldLoc);
330 
331  // Compute replacements for WarpOp results.
332  rewriter.replaceOp(warpOp, replacements);
333 
334  return success();
335  }
336 
337 private:
338  const WarpExecuteOnLane0LoweringOptions &options;
339 };
340 
341 /// Return the distributed vector type based on the original type and the
342 /// distribution map. The map is expected to have as many dimensions as the
343 /// original type rank and should be a projection where the results are the
344 /// distributed dimensions. The number of results should be equal to the number
345 /// of warp sizes, which is currently limited to 1.
346 /// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
347 /// and a warp size of 16 distributes the second dimension (associated with
348 /// d1) and returns vector<16x2x64>.
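/// When a mapped dimension is smaller than the warp size, the remaining warp
/// factor carries over to the next mapped dimension. A second worked example
/// (illustrative): vector<2x64> with the map (d0, d1) -> (d0, d1) and a warp
/// size of 32:
/// ```
/// d0: 2 lanes used  -> size becomes 1, 16 lanes remain
/// d1: 16 lanes used -> size becomes 4
/// result: vector<1x4>
/// ```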
349 static VectorType getDistributedType(VectorType originalType, AffineMap map,
350  int64_t warpSize) {
351  SmallVector<int64_t> targetShape(originalType.getShape());
352  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
353  unsigned position = map.getDimPosition(i);
354  if (targetShape[position] % warpSize != 0) {
355  if (warpSize % targetShape[position] != 0) {
356  return VectorType();
357  }
358  warpSize /= targetShape[position];
359  targetShape[position] = 1;
360  continue;
361  }
362  targetShape[position] = targetShape[position] / warpSize;
363  warpSize = 1;
364  break;
365  }
366  if (warpSize != 1) {
367  return VectorType();
368  }
369  VectorType targetType =
370  VectorType::get(targetShape, originalType.getElementType());
371  return targetType;
372 }
373 
374 /// Distribute transfer_write ops based on the affine map returned by
375 /// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
376 /// will not be extracted (this limit should be less than the warp size).
377 ///
378 /// Example:
379 /// ```
380 /// %0 = gpu.warp_execute_on_lane_0(%id){
381 /// ...
382 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
383 /// gpu.yield
384 /// }
385 /// ```
386 /// To
387 /// ```
388 /// %r:3 = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
389 /// ...
390 /// gpu.yield %v : vector<32xf32>
391 /// }
392 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
393 struct WarpOpTransferWrite : public WarpDistributionPattern {
394  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
395  unsigned maxNumElementsToExtract, PatternBenefit b = 1)
396  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
397  maxNumElementsToExtract(maxNumElementsToExtract) {}
398 
399  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
400  /// are multiples of the distribution ratio are supported at the moment.
401  LogicalResult tryDistributeOp(RewriterBase &rewriter,
402  vector::TransferWriteOp writeOp,
403  WarpExecuteOnLane0Op warpOp) const {
404  VectorType writtenVectorType = writeOp.getVectorType();
405 
406  // 1. Bail on 0-D writes: distributing them is not supported by this pattern
407  // and they are left in the warp op as is.
408  if (writtenVectorType.getRank() == 0)
409  return failure();
410 
411  // 2. Compute the distributed type.
412  AffineMap map = distributionMapFn(writeOp.getVector());
413  VectorType targetType =
414  getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
415  if (!targetType)
416  return failure();
417 
418  // 2.5 Compute the distributed type for the new mask;
419  VectorType maskType;
420  if (writeOp.getMask()) {
421  // TODO: Distribution of masked writes with non-trivial permutation maps
422  // requires the distribution of the mask to elementwise match the
423  // distribution of the permuted written vector. Currently the details
424  // of which lane is responsible for which element is captured strictly
425  // by shape information on the warp op, and thus requires materializing
426  // the permutation in IR.
427  if (!writeOp.getPermutationMap().isMinorIdentity())
428  return failure();
429  maskType =
430  getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
431  }
432 
433  // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
434  // the rest.
435  vector::TransferWriteOp newWriteOp =
436  cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
437 
438  // 4. Reindex the write using the distribution map.
439  auto newWarpOp =
440  newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
441 
442  // Delinearize the lane id based on the way threads are divided across the
443  // vector. To get the number of threads per vector dimension, divide the
444  // sequential size by the distributed size along each dim.
445  rewriter.setInsertionPoint(newWriteOp);
446  SmallVector<OpFoldResult> delinearizedIdSizes;
447  for (auto [seqSize, distSize] :
448  llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
449  assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
450  delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
451  }
452  SmallVector<Value> delinearized;
453  if (map.getNumResults() > 1) {
454  delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
455  rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
456  delinearizedIdSizes)
457  .getResults();
458  } else {
459  // If there is only one map result, we can elide the delinearization
460  // op and use the lane id directly.
461  delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
462  }
463 
464  AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
465  Location loc = newWriteOp.getLoc();
466  SmallVector<Value> indices(newWriteOp.getIndices().begin(),
467  newWriteOp.getIndices().end());
468  for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
469  AffineExpr d0, d1;
470  bindDims(newWarpOp.getContext(), d0, d1);
471  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
472  if (!indexExpr)
473  continue;
474  unsigned indexPos = indexExpr.getPosition();
475  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
476  Value laneId = delinearized[vectorPos];
477  auto scale =
478  rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
479  indices[indexPos] = affine::makeComposedAffineApply(
480  rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
481  }
482  newWriteOp.getIndicesMutable().assign(indices);
483 
484  return success();
485  }
486 
487  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
488  LogicalResult tryExtractOp(RewriterBase &rewriter,
489  vector::TransferWriteOp writeOp,
490  WarpExecuteOnLane0Op warpOp) const {
491  Location loc = writeOp.getLoc();
492  VectorType vecType = writeOp.getVectorType();
493 
494  if (vecType.getNumElements() > maxNumElementsToExtract) {
495  return rewriter.notifyMatchFailure(
496  warpOp,
497  llvm::formatv(
498  "writes more elements ({0}) than allowed to extract ({1})",
499  vecType.getNumElements(), maxNumElementsToExtract));
500  }
501 
502  // Do not process warp ops that contain only TransferWriteOps.
503  if (llvm::all_of(warpOp.getOps(),
504  llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
505  return failure();
506 
507  SmallVector<Value> yieldValues = {writeOp.getVector()};
508  SmallVector<Type> retTypes = {vecType};
509  SmallVector<size_t> newRetIndices;
510  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
511  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
512  rewriter.setInsertionPointAfter(newWarpOp);
513 
514  // Create a second warp op that contains only writeOp.
515  auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
516  newWarpOp.getLaneid(),
517  newWarpOp.getWarpSize());
518  Block &body = secondWarpOp.getBodyRegion().front();
519  rewriter.setInsertionPointToStart(&body);
520  auto newWriteOp =
521  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
522  newWriteOp.getValueToStoreMutable().assign(
523  newWarpOp.getResult(newRetIndices[0]));
524  rewriter.eraseOp(writeOp);
525  gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
526  return success();
527  }
528 
529  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
530  PatternRewriter &rewriter) const override {
531  gpu::YieldOp yield = warpOp.getTerminator();
532  Operation *lastNode = yield->getPrevNode();
533  auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
534  if (!writeOp)
535  return failure();
536 
537  Value maybeMask = writeOp.getMask();
538  if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
539  return writeOp.getVector() == value ||
540  (maybeMask && maybeMask == value) ||
541  warpOp.isDefinedOutsideOfRegion(value);
542  }))
543  return failure();
544 
545  if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
546  return success();
547 
548  // Masked writes not supported for extraction.
549  if (writeOp.getMask())
550  return failure();
551 
552  if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
553  return success();
554 
555  return failure();
556  }
557 
558 private:
559  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
560  /// execute op with the proper return type. The new write op is updated to
561  /// write the result of the new warp execute op. The old `writeOp` is deleted.
562  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
563  WarpExecuteOnLane0Op warpOp,
564  vector::TransferWriteOp writeOp,
565  VectorType targetType,
566  VectorType maybeMaskType) const {
567  assert(writeOp->getParentOp() == warpOp &&
568  "write must be nested immediately under warp");
569  OpBuilder::InsertionGuard g(rewriter);
570  SmallVector<size_t> newRetIndices;
571  WarpExecuteOnLane0Op newWarpOp;
572  if (maybeMaskType) {
573  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
574  rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
575  TypeRange{targetType, maybeMaskType}, newRetIndices);
576  } else {
577  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
578  rewriter, warpOp, ValueRange{{writeOp.getVector()}},
579  TypeRange{targetType}, newRetIndices);
580  }
581  rewriter.setInsertionPointAfter(newWarpOp);
582  auto newWriteOp =
583  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
584  rewriter.eraseOp(writeOp);
585  newWriteOp.getValueToStoreMutable().assign(
586  newWarpOp.getResult(newRetIndices[0]));
587  if (maybeMaskType)
588  newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
589  return newWriteOp;
590  }
591 
592  DistributionMapFn distributionMapFn;
593  unsigned maxNumElementsToExtract = 1;
594 };
595 
596 /// Sink out elementwise op feeding into a warp op yield.
597 /// ```
598 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
599 /// ...
600 /// %3 = arith.addf %1, %2 : vector<32xf32>
601 /// gpu.yield %3 : vector<32xf32>
602 /// }
603 /// ```
604 /// To
605 /// ```
606 /// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
607 /// vector<1xf32>, vector<1xf32>) {
608 /// ...
609 /// %4 = arith.addf %2, %3 : vector<32xf32>
610 /// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
611 /// vector<32xf32>
612 /// }
613 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
614 struct WarpOpElementwise : public WarpDistributionPattern {
615  using Base::Base;
616  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
617  PatternRewriter &rewriter) const override {
618  OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
619  return OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1;
620  });
621  if (!yieldOperand)
622  return failure();
623 
624  Operation *elementWise = yieldOperand->get().getDefiningOp();
625  unsigned operandIndex = yieldOperand->getOperandNumber();
626  Value distributedVal = warpOp.getResult(operandIndex);
627  SmallVector<Value> yieldValues;
628  SmallVector<Type> retTypes;
629  Location loc = warpOp.getLoc();
630  for (OpOperand &operand : elementWise->getOpOperands()) {
631  Type targetType;
632  if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
633  // If the result type is a vector, the operands must also be vectors.
634  auto operandType = cast<VectorType>(operand.get().getType());
635  targetType =
636  VectorType::get(vecType.getShape(), operandType.getElementType());
637  } else {
638  auto operandType = operand.get().getType();
639  assert(!isa<VectorType>(operandType) &&
640  "unexpected yield of vector from op with scalar result type");
641  targetType = operandType;
642  }
643  retTypes.push_back(targetType);
644  yieldValues.push_back(operand.get());
645  }
646  SmallVector<size_t> newRetIndices;
647  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
648  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
649  rewriter.setInsertionPointAfter(newWarpOp);
650  SmallVector<Value> newOperands(elementWise->getOperands().begin(),
651  elementWise->getOperands().end());
652  for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
653  newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
654  }
655  OpBuilder::InsertionGuard g(rewriter);
656  rewriter.setInsertionPointAfter(newWarpOp);
657  Operation *newOp = cloneOpWithOperandsAndTypes(
658  rewriter, loc, elementWise, newOperands,
659  {newWarpOp.getResult(operandIndex).getType()});
660  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
661  newOp->getResult(0));
662  return success();
663  }
664 };
665 
666 /// Sink out splat constant op feeding into a warp op yield.
667 /// ```
668 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
669 /// ...
670 /// %cst = arith.constant dense<2.0> : vector<32xf32>
671 /// gpu.yield %cst : vector<32xf32>
672 /// }
673 /// ```
674 /// To
675 /// ```
676 /// gpu.warp_execute_on_lane_0(%arg0 {
677 /// ...
678 /// }
679 /// %0 = arith.constant dense<2.0> : vector<1xf32>
680 struct WarpOpConstant : public WarpDistributionPattern {
681  using Base::Base;
682  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
683  PatternRewriter &rewriter) const override {
684  OpOperand *yieldOperand =
685  getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
686  if (!yieldOperand)
687  return failure();
688  auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
689  auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
690  if (!dense)
691  return failure();
692  // Notify the rewriter that the warp op is changing (see the comment on
693  // the WarpOpTransferRead pattern).
694  rewriter.startOpModification(warpOp);
695  unsigned operandIndex = yieldOperand->getOperandNumber();
696  Attribute scalarAttr = dense.getSplatValue<Attribute>();
697  auto newAttr = DenseElementsAttr::get(
698  cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
699  Location loc = warpOp.getLoc();
700  rewriter.setInsertionPointAfter(warpOp);
701  Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
702  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
703  rewriter.finalizeOpModification(warpOp);
704  return success();
705  }
706 };
707 
708 /// Sink out step op feeding into a warp op yield.
709 /// The vector step op is treated similarly to arith.constant, except that
710 /// its result represents the sequence [0, vec_size).
711 /// Due to the vec_size == warp_size limitation,
712 /// we can simply wrap the lane id into a vector (i.e., broadcast).
713 /// Supporting vec_size != warp_size may involve preserving the step
714 /// result and using additional arith ops (the exact details are TBD).
715 /// ```
716 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
717 /// ...
718 /// %cst = vector.step : vector<32xindex>
719 /// gpu.yield %cst : vector<32xindex>
720 /// }
721 /// ```
722 /// To
723 /// ```
724 /// gpu.warp_execute_on_lane_0(%arg0) {
725 /// ...
726 /// }
727 /// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
728 struct WarpOpStep final : public WarpDistributionPattern {
729  using Base::Base;
730  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
731  PatternRewriter &rewriter) const override {
732  OpOperand *yieldOperand =
733  getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
734  if (!yieldOperand)
735  return failure();
736  const unsigned operandIdx = yieldOperand->getOperandNumber();
737  auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
738  VectorType resTy = stepOp.getResult().getType();
739  if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
740  return rewriter.notifyMatchFailure(
741  warpOp,
742  llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
743  resTy.getNumElements(), warpOp.getWarpSize()));
744  VectorType newVecTy =
745  cast<VectorType>(warpOp.getResult(operandIdx).getType());
746  rewriter.setInsertionPointAfter(warpOp);
747  Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
748  newVecTy, warpOp.getLaneid());
749  rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
750  return success();
751  }
752 };
753 
754 /// Sink out transfer_read op feeding into a warp op yield.
755 /// ```
756 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
757 /// ...
758 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
759 /// vector<32xf32>
760 /// gpu.yield %2 : vector<32xf32>
761 /// }
762 /// ```
763 /// To
764 /// ```
765 /// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
766 /// vector<1xf32>, vector<1xf32>) {
767 /// ...
768 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
769 /// vector<32xf32>
/// gpu.yield %2 : vector<32xf32>
770 /// }
771 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
772 struct WarpOpTransferRead : public WarpDistributionPattern {
773  using Base::Base;
774  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
775  PatternRewriter &rewriter) const override {
776  // Try to find a distributable yielded read. Note that this pattern can
777  // still fail at the end after distribution, in which case this might have
778  // missed another distributable read.
779  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
780  // Don't duplicate transfer_read ops when distributing.
781  return isa<vector::TransferReadOp>(op) && op->hasOneUse();
782  });
783  if (!operand)
784  return rewriter.notifyMatchFailure(
785  warpOp, "warp result is not a vector.transfer_read op");
786  auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
787 
788  // Source must be defined outside of the region.
789  if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
790  return rewriter.notifyMatchFailure(
791  read, "source must be defined outside of the region");
792 
793  unsigned operandIndex = operand->getOperandNumber();
794  Value distributedVal = warpOp.getResult(operandIndex);
795 
796  SmallVector<Value, 4> indices(read.getIndices().begin(),
797  read.getIndices().end());
798  auto sequentialType = cast<VectorType>(read.getResult().getType());
799  auto distributedType = cast<VectorType>(distributedVal.getType());
800  AffineMap map = calculateImplicitMap(sequentialType, distributedType);
801  AffineMap indexMap = map.compose(read.getPermutationMap());
802 
803  // Try to delinearize the lane ID to match the rank expected for
804  // distribution.
805  SmallVector<Value> delinearizedIds;
806  if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
807  distributedType.getShape(), warpOp.getWarpSize(),
808  warpOp.getLaneid(), delinearizedIds)) {
809  return rewriter.notifyMatchFailure(
810  read, "cannot delinearize lane ID for distribution");
811  }
812  assert(!delinearizedIds.empty() || map.getNumResults() == 0);
813 
814  // Distribute indices and the mask (if present).
815  OpBuilder::InsertionGuard g(rewriter);
816  SmallVector<Value> additionalResults(indices.begin(), indices.end());
817  SmallVector<Type> additionalResultTypes(indices.size(),
818  rewriter.getIndexType());
819  additionalResults.push_back(read.getPadding());
820  additionalResultTypes.push_back(read.getPadding().getType());
821 
822  bool hasMask = false;
823  if (read.getMask()) {
824  hasMask = true;
825  // TODO: Distribution of masked reads with non-trivial permutation maps
826  // requires the distribution of the mask to elementwise match the
827  // distribution of the permuted written vector. Currently the details
828  // of which lane is responsible for which element is captured strictly
829  // by shape information on the warp op, and thus requires materializing
830  // the permutation in IR.
831  if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
832  return rewriter.notifyMatchFailure(
833  read, "non-trivial permutation maps not supported");
834  VectorType maskType =
835  getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
836  additionalResults.push_back(read.getMask());
837  additionalResultTypes.push_back(maskType);
838  }
839 
840  SmallVector<size_t> newRetIndices;
841  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
842  rewriter, warpOp, additionalResults, additionalResultTypes,
843  newRetIndices);
844  distributedVal = newWarpOp.getResult(operandIndex);
845 
846  // Distributed indices were appended first.
847  SmallVector<Value> newIndices;
848  for (int64_t i = 0, e = indices.size(); i < e; ++i)
849  newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
850 
851  rewriter.setInsertionPointAfter(newWarpOp);
852  for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
853  AffineExpr d0, d1;
854  bindDims(read.getContext(), d0, d1);
855  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
856  if (!indexExpr)
857  continue;
858  unsigned indexPos = indexExpr.getPosition();
859  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
860  int64_t scale = distributedType.getDimSize(vectorPos);
861  newIndices[indexPos] = affine::makeComposedAffineApply(
862  rewriter, read.getLoc(), d0 + scale * d1,
863  {newIndices[indexPos], delinearizedIds[vectorPos]});
864  }
865 
866  // Distributed padding value was appended right after the indices.
867  Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
868  // Distributed mask value was added at the end (if the op has a mask).
869  Value newMask =
870  hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
871  : Value();
872  auto newRead = vector::TransferReadOp::create(
873  rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
874  newIndices, read.getPermutationMapAttr(), newPadding, newMask,
875  read.getInBoundsAttr());
876 
877  rewriter.replaceAllUsesWith(distributedVal, newRead);
878  return success();
879  }
880 };
881 
882 /// Remove any result that has no use along with the matching yieldOp operand.
883 // TODO: Move this to WarpExecuteOnLane0Op canonicalization.
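/// A sketch (illustrative): if the same value is yielded twice and one result
/// is unused, the warp op is rebuilt with a single deduplicated result.
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<1xf32>, vector<1xf32>) {
///   ...
///   gpu.yield %v, %v : vector<32xf32>, vector<32xf32>
/// }
/// // If only %r#0 is used, this becomes:
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// ```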
884 struct WarpOpDeadResult : public WarpDistributionPattern {
885  using Base::Base;
886  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
887  PatternRewriter &rewriter) const override {
888  SmallVector<Type> newResultTypes;
889  newResultTypes.reserve(warpOp->getNumResults());
890  SmallVector<Value> newYieldValues;
891  newYieldValues.reserve(warpOp->getNumResults());
892  DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
893  DenseMap<OpResult, int64_t> dedupResultPositionMap;
894  gpu::YieldOp yield = warpOp.getTerminator();
895 
896  // Some values may be yielded multiple times and correspond to multiple
897  // results. Deduplicating occurs by taking each result with its matching
898  // yielded value, and:
899  // 1. recording the unique first position at which the value is yielded.
900  // 2. recording for the result, the first position at which the dedup'ed
901  // value is yielded.
902  // 3. skipping from the new result types / new yielded values any result
903  // that has no use or whose yielded value has already been seen.
904  for (OpResult result : warpOp.getResults()) {
905  Value yieldOperand = yield.getOperand(result.getResultNumber());
906  auto it = dedupYieldOperandPositionMap.insert(
907  std::make_pair(yieldOperand, newResultTypes.size()));
908  dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
909  if (result.use_empty() || !it.second)
910  continue;
911  newResultTypes.push_back(result.getType());
912  newYieldValues.push_back(yieldOperand);
913  }
914  // No modification, exit early.
915  if (yield.getNumOperands() == newYieldValues.size())
916  return failure();
917  // Move the body of the old warpOp to a new warpOp.
918  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
919  rewriter, warpOp, newYieldValues, newResultTypes);
920 
921  // Simplify the new warp op after dropping dead results.
922  newWarpOp.getBody()->walk([&](Operation *op) {
923  if (isOpTriviallyDead(op))
924  rewriter.eraseOp(op);
925  });
926 
927  // Replace results of the old warpOp by the new, deduplicated results.
928  SmallVector<Value> newValues;
929  newValues.reserve(warpOp->getNumResults());
930  for (OpResult result : warpOp.getResults()) {
931  if (result.use_empty())
932  newValues.push_back(Value());
933  else
934  newValues.push_back(
935  newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
936  }
937  rewriter.replaceOp(warpOp, newValues);
938  return success();
939  }
940 };
941 
942 // If an operand is yielded out of the region unchanged, it can be forwarded
943 // to the matching result and does not need to go through the region.
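// A sketch (illustrative): a warp operand that is yielded unchanged through a
// block argument is forwarded to the matching result.
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid)[32] args(%v : f32) -> (f32) {
// ^bb0(%arg0: f32):
//   gpu.yield %arg0 : f32
// }
// // Uses of %r are replaced with %v directly.
// ```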
944 struct WarpOpForwardOperand : public WarpDistributionPattern {
945  using Base::Base;
946  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
947  PatternRewriter &rewriter) const override {
948  gpu::YieldOp yield = warpOp.getTerminator();
949  Value valForwarded;
950  unsigned resultIndex;
951  for (OpOperand &operand : yield->getOpOperands()) {
952  Value result = warpOp.getResult(operand.getOperandNumber());
953  if (result.use_empty())
954  continue;
955 
956  // Assume all the values coming from above are uniform.
957  if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
958  if (result.getType() != operand.get().getType())
959  continue;
960  valForwarded = operand.get();
961  resultIndex = operand.getOperandNumber();
962  break;
963  }
964  auto arg = dyn_cast<BlockArgument>(operand.get());
965  if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
966  continue;
967  Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
968  if (result.getType() != warpOperand.getType())
969  continue;
970  valForwarded = warpOperand;
971  resultIndex = operand.getOperandNumber();
972  break;
973  }
974  if (!valForwarded)
975  return failure();
976  // Notify the rewriter that the warp op is changing (see the comment on
977  // the WarpOpTransferRead pattern).
978  rewriter.startOpModification(warpOp);
979  rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
980  rewriter.finalizeOpModification(warpOp);
981  return success();
982  }
983 };
984 
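/// Sink out a broadcast op feeding into a warp op yield. The broadcast source
/// is yielded from the warp op unchanged (it must be a value every lane can
/// rebroadcast, which is checked via `vector::isBroadcastableTo`), and the
/// broadcast is re-created after the warp op on the distributed result type.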
985 struct WarpOpBroadcast : public WarpDistributionPattern {
986  using Base::Base;
987  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
988  PatternRewriter &rewriter) const override {
989  OpOperand *operand =
990  getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
991  if (!operand)
992  return failure();
993  unsigned int operandNumber = operand->getOperandNumber();
994  auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
995  Location loc = broadcastOp.getLoc();
996  auto destVecType =
997  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
998  Value broadcastSrc = broadcastOp.getSource();
999  Type broadcastSrcType = broadcastSrc.getType();
1000 
1001  // Check that the broadcast actually spans a set of values uniformly across
1002  // all threads. In other words, check that each thread can reconstruct
1003  // their own broadcast.
1004  // For that we simply check that the broadcast we want to build makes sense.
1005  if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1006  vector::BroadcastableToResult::Success)
1007  return failure();
1008  SmallVector<size_t> newRetIndices;
1009  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1010  rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1011  rewriter.setInsertionPointAfter(newWarpOp);
1012  Value broadcasted = vector::BroadcastOp::create(
1013  rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1014  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1015  broadcasted);
1016  return success();
1017  }
1018 };
1019 
1020 /// Pattern to move shape cast out of the warp op. shape cast is basically a
1021 /// no-op for warp distribution; we need to handle the shape though.
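/// A sketch (illustrative, warp size 32, leading dimension distributed):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %cast = vector.shape_cast %v : vector<32x1xf32> to vector<32xf32>
///   gpu.yield %cast : vector<32xf32>
/// }
/// // becomes
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
///   gpu.yield %v : vector<32x1xf32>
/// }
/// %0 = vector.shape_cast %r : vector<1x1xf32> to vector<1xf32>
/// ```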
1022 struct WarpOpShapeCast : public WarpDistributionPattern {
1023  using Base::Base;
1024  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1025  PatternRewriter &rewriter) const override {
1026  OpOperand *operand =
1027  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1028  if (!operand)
1029  return failure();
1030 
1031  auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1032 
1033  unsigned int operandNumber = operand->getOperandNumber();
1034  auto castDistributedType =
1035  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1036  VectorType castOriginalType = oldCastOp.getSourceVectorType();
1037  VectorType castResultType = castDistributedType;
1038 
1039  // We expect the distributed type to have a smaller rank than the original
1040  // type. Prepend with size-one dimensions to make them the same.
1041  unsigned castDistributedRank = castDistributedType.getRank();
1042  unsigned castOriginalRank = castOriginalType.getRank();
1043  if (castDistributedRank < castOriginalRank) {
1044  SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1045  llvm::append_range(shape, castDistributedType.getShape());
1046  castDistributedType =
1047  VectorType::get(shape, castDistributedType.getElementType());
1048  }
1049 
1050  SmallVector<size_t> newRetIndices;
1051  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1052  rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1053  newRetIndices);
1054  rewriter.setInsertionPointAfter(newWarpOp);
1055  Value newCast = vector::ShapeCastOp::create(
1056  rewriter, oldCastOp.getLoc(), castResultType,
1057  newWarpOp->getResult(newRetIndices[0]));
1058  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1059  return success();
1060  }
1061 };
1062 
1063 /// Sink out vector.create_mask op feeding into a warp op yield.
1064 /// ```
1065 /// %0 = ...
1066 /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1067 /// ...
1068 /// %mask = vector.create_mask %0 : vector<32xi1>
1069 /// gpu.yield %mask : vector<32xi1>
1070 /// }
1071 /// ```
1072 /// To
1073 /// ```
1074 /// %0 = ...
1075 /// gpu.warp_execute_on_lane_0(%arg0) {
1076 /// ...
1077 /// }
1078 /// %cmp = arith.cmpi ult, %laneid, %0
1079 /// %ub = arith.select %cmp, %c0, %c1
1080 /// %1 = vector.create_mask %ub : vector<1xi1>
1081 struct WarpOpCreateMask : public WarpDistributionPattern {
1082  using Base::Base;
1083  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1084  PatternRewriter &rewriter) const override {
1085  OpOperand *yieldOperand =
1086  getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1087  if (!yieldOperand)
1088  return failure();
1089 
1090  auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1091 
1092  // Early exit if any values needed for calculating the new mask indices
1093  // are defined inside the warp op.
1094  if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1095  return warpOp.isDefinedOutsideOfRegion(value);
1096  }))
1097  return failure();
1098 
1099  Location loc = mask.getLoc();
1100  unsigned operandIndex = yieldOperand->getOperandNumber();
1101 
1102  auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1103  VectorType seqType = mask.getVectorType();
1104  ArrayRef<int64_t> seqShape = seqType.getShape();
1105  ArrayRef<int64_t> distShape = distType.getShape();
1106 
1107  rewriter.setInsertionPointAfter(warpOp);
1108 
1109  // Delinearize the lane ID for constructing the distributed mask sizes.
1110  SmallVector<Value> delinearizedIds;
1111  if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1112  warpOp.getWarpSize(), warpOp.getLaneid(),
1113  delinearizedIds))
1114  return rewriter.notifyMatchFailure(
1115  mask, "cannot delinearize lane ID for distribution");
1116  assert(!delinearizedIds.empty());
1117 
1118  // Notify the rewriter that the warp op is changing (see the comment on
1119  // the WarpOpTransferRead pattern).
1120  rewriter.startOpModification(warpOp);
1121 
1122  AffineExpr s0, s1;
1123  bindSymbols(rewriter.getContext(), s0, s1);
1124  SmallVector<Value> newOperands;
1125  for (int i = 0, e = distShape.size(); i < e; ++i) {
1126  // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1127  // find the distance from the largest mask index owned by this lane to the
1128  // original mask size. `vector.create_mask` implicitly clamps mask
1129  // operands to the range [0, mask_vector_size[i]], or in other words, the
1130  // mask sizes are always in the range [0, mask_vector_size[i]).
1131  Value maskDimIdx = affine::makeComposedAffineApply(
1132  rewriter, loc, s1 - s0 * distShape[i],
1133  {delinearizedIds[i], mask.getOperand(i)});
1134  newOperands.push_back(maskDimIdx);
1135  }
1136 
1137  auto newMask =
1138  vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
1139  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1140  rewriter.finalizeOpModification(warpOp);
1141  return success();
1142  }
1143 };
1144 
1145 /// Sink out insert_strided_slice op feeding into a warp op yield.
1146 /// ```
1147 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1148 /// ...
1149 /// %src = ... : vector<4x32xf32>
1150 /// %dest = ... : vector<8x32xf32>
1151 /// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1152 /// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1153 /// gpu.yield %insert : vector<8x32xf32>
1154 /// }
1155 /// ```
1156 /// To
1157 /// ```
1158 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1159 /// vector<8x1xf32>) {
1160 /// ...
1161 /// %src = ... : vector<4x32xf32>
1162 /// %dest = ... : vector<8x32xf32>
1163 /// gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
1164 /// }
1165 /// %insert = vector.insert_strided_slice %0#0, %0#1,
1166 /// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1167 /// ```
1168 /// NOTE: Current support assumes that both src and dest vectors are distributed
1169 /// to lanes and sinking the insert op does not require any cross lane
1170 /// communication.
1171 struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
1172  using Base::Base;
1173  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1174  PatternRewriter &rewriter) const override {
1175  OpOperand *operand =
1176  getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1177  if (!operand)
1178  return failure();
1179  unsigned int operandNumber = operand->getOperandNumber();
1180  auto insertOp =
1181  operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1182  auto distributedType =
1183  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1184  // Distributed type must be 2D or higher.
1185  // TODO: Support 1D distributed types.
1186  if (distributedType.getRank() < 2)
1187  return rewriter.notifyMatchFailure(
1188  insertOp, "result vector type must be 2D or higher");
1189  // Find the distributed dimension of the dest vector. There should be
1190  // exactly one.
1191  auto yieldedType = cast<VectorType>(operand->get().getType());
1192  int64_t destDistributedDim =
1193  getDistributedDim(yieldedType, distributedType);
1194  assert(destDistributedDim != -1 && "could not find distributed dimension");
1195 
1196  VectorType srcType = insertOp.getSourceVectorType();
1197  VectorType destType = insertOp.getDestVectorType();
1198  // Currently we require that both source (kD) and dest (nD) vectors are
1199  // distributed. This requires that distributedDim (d) is contained in the
1200  // last k dims of the dest vector (d >= n - k).
1201  // TODO: Add support for case where source vector is not distributed.
1202  int64_t sourceDistributedDim =
1203  destDistributedDim - (destType.getRank() - srcType.getRank());
1204  if (sourceDistributedDim < 0)
1205  return rewriter.notifyMatchFailure(
1206  insertOp,
1207  "distributed dimension must be in the last k dims of dest vector");
1208  // Distributed dimension must be fully inserted.
1209  if (srcType.getDimSize(sourceDistributedDim) !=
1210  destType.getDimSize(destDistributedDim))
1211  return rewriter.notifyMatchFailure(
1212  insertOp, "distributed dimension must be fully inserted");
1213  SmallVector<int64_t> newSourceDistShape(
1214  insertOp.getSourceVectorType().getShape());
1215  newSourceDistShape[sourceDistributedDim] =
1216  distributedType.getDimSize(destDistributedDim);
1217  auto newSourceTy =
1218  VectorType::get(newSourceDistShape, distributedType.getElementType());
1219  VectorType newDestTy = distributedType;
1220  SmallVector<size_t> newRetIndices;
1221  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1222  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1223  {newSourceTy, newDestTy}, newRetIndices);
1224  rewriter.setInsertionPointAfter(newWarpOp);
1225  Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1226  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1227  // Create a new insert strided slice op that inserts distributed source into
1228  // distributed dest.
1229  Value newInsert = vector::InsertStridedSliceOp::create(
1230  rewriter, insertOp.getLoc(), distributedDest.getType(),
1231  distributedSource, distributedDest, insertOp.getOffsets(),
1232  insertOp.getStrides());
1233  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
1234  return success();
1235  }
1236 };
1237 
1238 /// Sink out extract_strided_slice op feeding into a warp op yield.
1239 /// ```
1240 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1241 /// ...
1242 /// %src = ... : vector<64x32xf32>
1243 /// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1244 /// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1245 /// gpu.yield %extract : vector<16x32xf32>
1246 /// }
1247 /// ```
1248 /// To
1249 /// ```
1250 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1251 /// ...
1252 /// %src = ... : vector<64x32xf32>
1253 /// gpu.yield %src : vector<64x32xf32>
1254 /// }
1255 /// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1256 /// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1257 /// ```
1258 /// NOTE: Current support assumes that the extraction happens only on non
1259 /// distributed dimensions (does not require cross lane communication).
1260 struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
1261  using Base::Base;
1262  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1263  PatternRewriter &rewriter) const override {
1264  OpOperand *operand =
1265  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1266  if (!operand)
1267  return failure();
1268  unsigned int operandNumber = operand->getOperandNumber();
1269  auto extractOp =
1270  operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
1271  auto distributedType =
1272  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1273  // Distributed type must be 2D or higher.
1274  // TODO: Support 1D distributed types.
1275  if (distributedType.getRank() < 2)
1276  return rewriter.notifyMatchFailure(
1277  extractOp, "result vector type must be 2D or higher");
1278 
1279  // Find the distributed dimension. There should be exactly one.
1280  auto yieldedType = cast<VectorType>(operand->get().getType());
1281  int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1282  assert(distributedDim != -1 && "could not find distributed dimension");
1283 
1284  int64_t numOfExtractedDims =
1285  static_cast<int64_t>(extractOp.getSizes().size());
1286  // If the distributed dim is included in the extracted dims, then we make
1287  // sure distributed dim is fully extracted. If distributed dim is not
1288  // included in extracted dims, it is guaranteed to be fully extracted (i.e.
1289  // distributed dim comes after all the extracted dims)
1290  // TODO: Partial extraction from the distributed dimension requires cross-lane
1291  // communication.
1292  if (distributedDim < numOfExtractedDims) {
1293  int64_t distributedDimOffset =
1294  llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1295  .getInt();
1296  int64_t distributedDimSize =
1297  llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1298  .getInt();
1299  if (distributedDimOffset != 0 ||
1300  distributedDimSize != yieldedType.getDimSize(distributedDim))
1301  return rewriter.notifyMatchFailure(
1302  extractOp, "distributed dimension must be fully extracted");
1303  }
1304  SmallVector<int64_t> newDistributedShape(
1305  extractOp.getSourceVectorType().getShape());
1306  newDistributedShape[distributedDim] =
1307  distributedType.getDimSize(distributedDim);
1308  auto newDistributedType =
1309  VectorType::get(newDistributedShape, distributedType.getElementType());
1310  SmallVector<size_t> newRetIndices;
1311  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1312  rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1313  newRetIndices);
1314  rewriter.setInsertionPointAfter(newWarpOp);
1315  SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1316  extractOp.getSizes(), [](Attribute attr) { return attr; });
1317  // Update the distributed sizes to match the distributed type.
1318  if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
1319  distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1320  distributedType.getDimSize(distributedDim));
1321 
1322  // Create a new extract strided slice op that extracts from the
1323  // distributed vector.
1324  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1325  Value newExtract = vector::ExtractStridedSliceOp::create(
1326  rewriter, extractOp.getLoc(), distributedType, distributedVec,
1327  extractOp.getOffsets(),
1328  ArrayAttr::get(rewriter.getContext(), distributedSizes),
1329  extractOp.getStrides());
1330  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1331  newExtract);
1332  return success();
1333  }
1334 };
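A minimal sketch of what this pattern produces, assuming a warp size of 32, a hypothetical `"some_def"` producer, and the trailing dimension distributed (the shapes are only illustrative): the slice keeps the distributed dimension fully intact, so it can be re-created on the distributed operand outside the region, with the sizes entry for the distributed dimension rewritten from 32 to 1.

```
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
  %src = "some_def"() : () -> (vector<16x32xf32>)
  %0 = vector.extract_strided_slice %src
      {offsets = [8, 0], sizes = [8, 32], strides = [1, 1]}
      : vector<16x32xf32> to vector<8x32xf32>
  gpu.yield %0 : vector<8x32xf32>
}
```

is rewritten to roughly:

```
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<16x1xf32>) {
  %src = "some_def"() : () -> (vector<16x32xf32>)
  gpu.yield %src : vector<16x32xf32>
}
%r = vector.extract_strided_slice %w
    {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]}
    : vector<16x1xf32> to vector<8x1xf32>
```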
1335 
1336 /// Pattern to move out vector.extract from 2-D and higher dimensional source
1337 /// vectors. 0-D and 1-D sources are handled by the WarpOpExtractScalar pattern.
1338 struct WarpOpExtract : public WarpDistributionPattern {
1339  using Base::Base;
1340  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1341  PatternRewriter &rewriter) const override {
1342  OpOperand *operand =
1343  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1344  if (!operand)
1345  return failure();
1346  unsigned int operandNumber = operand->getOperandNumber();
1347  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1348  VectorType extractSrcType = extractOp.getSourceVectorType();
1349  Location loc = extractOp.getLoc();
1350 
1351  // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1352  if (extractSrcType.getRank() <= 1) {
1353  return failure();
1354  }
1355 
1356  // All following cases are 2d or higher dimensional source vectors.
1357 
1358  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1359  // There is no distribution, this is a broadcast. Simply move the extract
1360  // out of the warp op.
1361  // TODO: This could be optimized. E.g., in case of a scalar result, let
1362  // one lane extract and shuffle the result to all other lanes (same as
1363  // the 1d case).
1364  SmallVector<size_t> newRetIndices;
1365  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1366  rewriter, warpOp, {extractOp.getVector()},
1367  {extractOp.getSourceVectorType()}, newRetIndices);
1368  rewriter.setInsertionPointAfter(newWarpOp);
1369  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1370  // Extract from distributed vector.
1371  Value newExtract = vector::ExtractOp::create(
1372  rewriter, loc, distributedVec, extractOp.getMixedPosition());
1373  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1374  newExtract);
1375  return success();
1376  }
1377 
1378  // Find the distributed dimension. There should be exactly one.
1379  auto distributedType =
1380  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1381  auto yieldedType = cast<VectorType>(operand->get().getType());
1382  int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1383  assert(distributedDim != -1 && "could not find distributed dimension");
1384  (void)distributedDim;
1385 
1386  // Yield source vector from warp op.
1387  SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1388  for (int i = 0; i < distributedType.getRank(); ++i)
1389  newDistributedShape[i + extractOp.getNumIndices()] =
1390  distributedType.getDimSize(i);
1391  auto newDistributedType =
1392  VectorType::get(newDistributedShape, distributedType.getElementType());
1393  SmallVector<size_t> newRetIndices;
1394  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1395  rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1396  newRetIndices);
1397  rewriter.setInsertionPointAfter(newWarpOp);
1398  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1399  // Extract from distributed vector.
1400  Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1401  extractOp.getMixedPosition());
1402  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1403  newExtract);
1404  return success();
1405  }
1406 };
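A minimal sketch of the distributed case handled above, assuming warp size 32 and a hypothetical `"some_def"` producer: the trailing dimension (96 elements, 3 per lane) is distributed, so the extract of row 2 is simply replayed on the distributed vector outside the region.

```
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
  %src = "some_def"() : () -> (vector<4x96xf32>)
  %0 = vector.extract %src[2] : vector<96xf32> from vector<4x96xf32>
  gpu.yield %0 : vector<96xf32>
}
```

becomes roughly:

```
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x3xf32>) {
  %src = "some_def"() : () -> (vector<4x96xf32>)
  gpu.yield %src : vector<4x96xf32>
}
%r = vector.extract %w[2] : vector<3xf32> from vector<4x3xf32>
```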
1407 
1408 /// Pattern to move out vector.extract with a scalar result.
1409 /// Only supports 1-D and 0-D sources for now.
1410 struct WarpOpExtractScalar : public WarpDistributionPattern {
1411  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1412  PatternBenefit b = 1)
1413  : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1414  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1415  PatternRewriter &rewriter) const override {
1416  OpOperand *operand =
1417  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1418  if (!operand)
1419  return failure();
1420  unsigned int operandNumber = operand->getOperandNumber();
1421  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1422  VectorType extractSrcType = extractOp.getSourceVectorType();
1423  // Only supports 1-D or 0-D sources for now.
1424  if (extractSrcType.getRank() > 1) {
1425  return rewriter.notifyMatchFailure(
1426  extractOp, "only 0-D or 1-D source supported for now");
1427  }
1428  // TODO: Supported shuffle types should be parameterizable, similar to
1429  // `WarpShuffleFromIdxFn`.
1430  if (!extractSrcType.getElementType().isF32() &&
1431  !extractSrcType.getElementType().isInteger(32))
1432  return rewriter.notifyMatchFailure(
1433  extractOp, "only f32/i32 element types are supported");
1434  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1435  Type elType = extractSrcType.getElementType();
1436  VectorType distributedVecType;
1437  if (!is0dOrVec1Extract) {
1438  assert(extractSrcType.getRank() == 1 &&
1439  "expected that extract src rank is 0 or 1");
1440  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1441  return failure();
1442  int64_t elementsPerLane =
1443  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1444  distributedVecType = VectorType::get({elementsPerLane}, elType);
1445  } else {
1446  distributedVecType = extractSrcType;
1447  }
1448  // Yield source vector and position (if present) from warp op.
1449  SmallVector<Value> additionalResults{extractOp.getVector()};
1450  SmallVector<Type> additionalResultTypes{distributedVecType};
1451  additionalResults.append(
1452  SmallVector<Value>(extractOp.getDynamicPosition()));
1453  additionalResultTypes.append(
1454  SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1455 
1456  Location loc = extractOp.getLoc();
1457  SmallVector<size_t> newRetIndices;
1458  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1459  rewriter, warpOp, additionalResults, additionalResultTypes,
1460  newRetIndices);
1461  rewriter.setInsertionPointAfter(newWarpOp);
1462  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1463 
1464  // 0-d or single-element extract: the new warp op broadcasts the source
1465  // vector to all lanes and every lane extracts the scalar.
1466  if (is0dOrVec1Extract) {
1467  Value newExtract;
1468  SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1469  newExtract =
1470  vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
1471  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1472  newExtract);
1473  return success();
1474  }
1475 
1476  int64_t staticPos = extractOp.getStaticPosition()[0];
1477  OpFoldResult pos = ShapedType::isDynamic(staticPos)
1478  ? (newWarpOp->getResult(newRetIndices[1]))
1479  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1480  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1481  // the value to all other lanes.
1482  int64_t elementsPerLane = distributedVecType.getShape()[0];
1483  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1484  // tid of extracting thread: pos / elementsPerLane
1485  Value broadcastFromTid = affine::makeComposedAffineApply(
1486  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1487  // Extract at position: pos % elementsPerLane
1488  Value newPos =
1489  elementsPerLane == 1
1490  ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
1491  : affine::makeComposedAffineApply(rewriter, loc,
1492  sym0 % elementsPerLane, pos);
1493  Value extracted =
1494  vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1495 
1496  // Shuffle the extracted value to all lanes.
1497  Value shuffled = warpShuffleFromIdxFn(
1498  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1499  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1500  return success();
1501  }
1502 
1503 private:
1504  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1505 };
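A sketch of the 1-D case, assuming warp size 32, f32 elements, and a hypothetical `"some_def"` producer: with 64 elements, each lane owns 2, so position 42 lives at local offset 0 of lane 21; every lane extracts its local element and the user-supplied `warpShuffleFromIdxFn` broadcasts lane 21's value to all lanes (e.g. via `gpu.shuffle idx`, depending on the target).

```
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
  %src = "some_def"() : () -> (vector<64xf32>)
  %0 = vector.extract %src[42] : f32 from vector<64xf32>
  gpu.yield %0 : f32
}
```

becomes roughly:

```
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
  %src = "some_def"() : () -> (vector<64xf32>)
  gpu.yield %src : vector<64xf32>
}
%e = vector.extract %w[0] : f32 from vector<2xf32>
// %r = %e shuffled from lane 21 to all lanes; the actual op is produced by
// warpShuffleFromIdxFn and is therefore target specific.
```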
1506 
1507 /// Pattern to move out vector.insert with a scalar input.
1508 /// Only supports 1-D and 0-D destinations for now.
1509 struct WarpOpInsertScalar : public WarpDistributionPattern {
1510  using Base::Base;
1511  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1512  PatternRewriter &rewriter) const override {
1513  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1514  if (!operand)
1515  return failure();
1516  unsigned int operandNumber = operand->getOperandNumber();
1517  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1518  VectorType vecType = insertOp.getDestVectorType();
1519  VectorType distrType =
1520  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1521 
1522  // Only supports 1-D or 0-D destinations for now.
1523  if (vecType.getRank() > 1) {
1524  return rewriter.notifyMatchFailure(
1525  insertOp, "only 0-D or 1-D destination supported for now");
1526  }
1527 
1528  // Yield destination vector, source scalar and position from warp op.
1529  SmallVector<Value> additionalResults{insertOp.getDest(),
1530  insertOp.getValueToStore()};
1531  SmallVector<Type> additionalResultTypes{
1532  distrType, insertOp.getValueToStore().getType()};
1533  additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1534  additionalResultTypes.append(
1535  SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1536 
1537  Location loc = insertOp.getLoc();
1538  SmallVector<size_t> newRetIndices;
1539  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1540  rewriter, warpOp, additionalResults, additionalResultTypes,
1541  newRetIndices);
1542  rewriter.setInsertionPointAfter(newWarpOp);
1543  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1544  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1545  rewriter.setInsertionPointAfter(newWarpOp);
1546 
1547  OpFoldResult pos;
1548  if (vecType.getRank() != 0) {
1549  int64_t staticPos = insertOp.getStaticPosition()[0];
1550  pos = ShapedType::isDynamic(staticPos)
1551  ? (newWarpOp->getResult(newRetIndices[2]))
1552  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1553  }
1554 
1555  // This condition is always true for 0-d vectors.
1556  if (vecType == distrType) {
1557  Value newInsert;
1558  SmallVector<OpFoldResult> indices;
1559  if (pos) {
1560  indices.push_back(pos);
1561  }
1562  newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1563  distributedVec, indices);
1564  // Broadcast: Simply move the vector.insert op out.
1565  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1566  newInsert);
1567  return success();
1568  }
1569 
1570  // This is a distribution. Only one lane should insert.
1571  int64_t elementsPerLane = distrType.getShape()[0];
1572  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1573  // tid of inserting lane: pos / elementsPerLane
1574  Value insertingLane = affine::makeComposedAffineApply(
1575  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1576  // Insert position: pos % elementsPerLane
1577  OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1578  rewriter, loc, sym0 % elementsPerLane, pos);
1579  Value isInsertingLane =
1580  arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1581  newWarpOp.getLaneid(), insertingLane);
1582  Value newResult =
1583  scf::IfOp::create(
1584  rewriter, loc, isInsertingLane,
1585  /*thenBuilder=*/
1586  [&](OpBuilder &builder, Location loc) {
1587  Value newInsert = vector::InsertOp::create(
1588  builder, loc, newSource, distributedVec, newPos);
1589  scf::YieldOp::create(builder, loc, newInsert);
1590  },
1591  /*elseBuilder=*/
1592  [&](OpBuilder &builder, Location loc) {
1593  scf::YieldOp::create(builder, loc, distributedVec);
1594  })
1595  .getResult(0);
1596  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1597  return success();
1598  }
1599 };
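A sketch of the distributed 1-D case, assuming warp size 32 and hypothetical `"some_def"`/`"some_scalar"` producers: inserting at position 42 of a 64-element vector (2 elements per lane) means only lane 21 performs the insert, at local offset 0; the affine expressions fold to constants here because the position is static.

```
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
  %dst = "some_def"() : () -> (vector<64xf32>)
  %f = "some_scalar"() : () -> (f32)
  %0 = vector.insert %f, %dst [42] : f32 into vector<64xf32>
  gpu.yield %0 : vector<64xf32>
}
```

becomes roughly:

```
%w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, f32) {
  %dst = "some_def"() : () -> (vector<64xf32>)
  %f = "some_scalar"() : () -> (f32)
  gpu.yield %dst, %f : vector<64xf32>, f32
}
%c21 = arith.constant 21 : index
%is_lane_21 = arith.cmpi eq, %laneid, %c21 : index
%r = scf.if %is_lane_21 -> (vector<2xf32>) {
  %i = vector.insert %w#1, %w#0 [0] : f32 into vector<2xf32>
  scf.yield %i : vector<2xf32>
} else {
  scf.yield %w#0 : vector<2xf32>
}
```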
1600 
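/// Pattern to move out vector.insert into a 2-D or higher dimensional
/// destination vector. 0-D and 1-D destinations are handled by the
/// WarpOpInsertScalar pattern above.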
1601 struct WarpOpInsert : public WarpDistributionPattern {
1602  using Base::Base;
1603  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1604  PatternRewriter &rewriter) const override {
1605  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1606  if (!operand)
1607  return failure();
1608  unsigned int operandNumber = operand->getOperandNumber();
1609  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1610  Location loc = insertOp.getLoc();
1611 
1612  // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1613  if (insertOp.getDestVectorType().getRank() <= 1) {
1614  return failure();
1615  }
1616 
1617  // All following cases are 2d or higher dimensional destination vectors.
1618 
1619  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1620  // There is no distribution, this is a broadcast. Simply move the insert
1621  // out of the warp op.
1622  SmallVector<size_t> newRetIndices;
1623  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1624  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1625  {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1626  newRetIndices);
1627  rewriter.setInsertionPointAfter(newWarpOp);
1628  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1629  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1630  Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1631  distributedDest,
1632  insertOp.getMixedPosition());
1633  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1634  newResult);
1635  return success();
1636  }
1637 
1638  // Find the distributed dimension. There should be exactly one.
1639  auto distrDestType =
1640  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1641  auto yieldedType = cast<VectorType>(operand->get().getType());
1642  int64_t distrDestDim = -1;
1643  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1644  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1645  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1646  // support distributing multiple dimensions in the future.
1647  assert(distrDestDim == -1 && "found multiple distributed dims");
1648  distrDestDim = i;
1649  }
1650  }
1651  assert(distrDestDim != -1 && "could not find distributed dimension");
1652 
1653  // Compute the distributed source vector type.
1654  VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1655  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1656  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1657  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1658  // insert a smaller vector<3xf32>.
1659  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1660  // case, one lane will insert the source vector<96xf32>. The other
1661  // lanes will not do anything.
1662  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1663  if (distrSrcDim >= 0)
1664  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1665  auto distrSrcType =
1666  VectorType::get(distrSrcShape, distrDestType.getElementType());
1667 
1668  // Yield source and dest vectors from warp op.
1669  SmallVector<size_t> newRetIndices;
1670  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1671  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1672  {distrSrcType, distrDestType}, newRetIndices);
1673  rewriter.setInsertionPointAfter(newWarpOp);
1674  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1675  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1676 
1677  // Insert into the distributed vector.
1678  Value newResult;
1679  if (distrSrcDim >= 0) {
1680  // Every lane inserts a small piece.
1681  newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1682  distributedDest,
1683  insertOp.getMixedPosition());
1684  } else {
1685  // One lane inserts the entire source vector.
1686  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1687  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1688  SmallVector<int64_t> newPos = getAsIntegers(pos);
1689  // tid of inserting lane: pos / elementsPerLane
1690  Value insertingLane = arith::ConstantIndexOp::create(
1691  rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1692  Value isInsertingLane =
1693  arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1694  newWarpOp.getLaneid(), insertingLane);
1695  // Insert position: pos % elementsPerLane
1696  newPos[distrDestDim] %= elementsPerLane;
1697  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1698  Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1699  distributedDest, newPos);
1700  scf::YieldOp::create(builder, loc, newInsert);
1701  };
1702  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1703  scf::YieldOp::create(builder, loc, distributedDest);
1704  };
1705  newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1706  /*thenBuilder=*/insertingBuilder,
1707  /*elseBuilder=*/nonInsertingBuilder)
1708  .getResult(0);
1709  }
1710 
1711  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1712  return success();
1713  }
1714 };
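A sketch of the single-lane insertion case (Case 2 in the comment above), assuming warp size 32 and hypothetical `"some_def"`/`"some_row"` producers: the outer dimension of 128 is distributed to 4 rows per lane, so row 2 of the destination belongs to lane 0 at local row 2, and only that lane performs the insert.

```
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
  %d = "some_def"() : () -> (vector<128x96xf32>)
  %s = "some_row"() : () -> (vector<96xf32>)
  %0 = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
  gpu.yield %0 : vector<128x96xf32>
}
```

becomes roughly:

```
%w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>, vector<4x96xf32>) {
  %d = "some_def"() : () -> (vector<128x96xf32>)
  %s = "some_row"() : () -> (vector<96xf32>)
  gpu.yield %s, %d : vector<96xf32>, vector<128x96xf32>
}
%c0 = arith.constant 0 : index
%is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
%r = scf.if %is_lane_0 -> (vector<4x96xf32>) {
  %i = vector.insert %w#0, %w#1 [2] : vector<96xf32> into vector<4x96xf32>
  scf.yield %i : vector<4x96xf32>
} else {
  scf.yield %w#1 : vector<4x96xf32>
}
```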
1715 
1716 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1717 /// the scf.ForOp is the last operation in the region so that it doesn't
1718 /// change the order of execution. This creates a new scf.for region after the
1719 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1720 /// WarpExecuteOnLane0Op region. Example:
1721 /// ```
1722 /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1723 /// ...
1724 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1725 /// -> (vector<128xf32>) {
1726 /// ...
1727 /// scf.yield %r : vector<128xf32>
1728 /// }
1729 /// gpu.yield %v1 : vector<128xf32>
1730 /// }
1731 /// ```
1732 /// To:
1733 /// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1734 /// ...
1735 /// gpu.yield %v : vector<128xf32>
1736 /// }
1737 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1738 /// -> (vector<4xf32>) {
1739 /// %iw = gpu.warp_execute_on_lane_0(%laneid)
1740 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1741 /// ^bb0(%arg: vector<128xf32>):
1742 /// ...
1743 /// gpu.yield %ir : vector<128xf32>
1744 /// }
1745 /// scf.yield %iw : vector<4xf32>
1746 /// }
1747 /// ```
1748 struct WarpOpScfForOp : public WarpDistributionPattern {
1749 
1750  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1751  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1752  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1753  PatternRewriter &rewriter) const override {
1754  gpu::YieldOp warpOpYield = warpOp.getTerminator();
1755  // Only pick up `ForOp` if it is the last op in the region.
1756  Operation *lastNode = warpOpYield->getPrevNode();
1757  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1758  if (!forOp)
1759  return failure();
1760  // Collect values defined in the `WarpOp` region but outside the `ForOp` and
1761  // used inside the `ForOp`. These escaping values must be yielded by the new warp op.
1762  llvm::SmallSetVector<Value, 32> escapingValues;
1763  SmallVector<Type> escapingValueInputTypes;
1764  SmallVector<Type> escapingValueDistTypes;
1765  mlir::visitUsedValuesDefinedAbove(
1766  forOp.getBodyRegion(), [&](OpOperand *operand) {
1767  Operation *parent = operand->get().getParentRegion()->getParentOp();
1768  if (warpOp->isAncestor(parent)) {
1769  if (!escapingValues.insert(operand->get()))
1770  return;
1771  Type distType = operand->get().getType();
1772  if (auto vecType = dyn_cast<VectorType>(distType)) {
1773  AffineMap map = distributionMapFn(operand->get());
1774  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1775  }
1776  escapingValueInputTypes.push_back(operand->get().getType());
1777  escapingValueDistTypes.push_back(distType);
1778  }
1779  });
1780 
1781  if (llvm::is_contained(escapingValueDistTypes, Type{}))
1782  return failure();
1783  // `WarpOp` can yield two types of values:
1784  // 1. Values that are not results of the `ForOp`:
1785  // These values must also be yielded by the new `WarpOp`. Also, we need
1786  // to record the index mapping for these values to replace them later.
1787  // 2. Values that are results of the `ForOp`:
1788  // In this case, we record the index mapping between the `WarpOp` result
1789  // index and matching `ForOp` result index.
1790  // Additionally, we keep track of the distributed types for all `ForOp`
1791  // vector results.
1792  SmallVector<Value> nonForYieldedValues;
1793  SmallVector<unsigned> nonForResultIndices;
1794  llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
1795  llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
1796  for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1797  // Yielded value is not a result of the forOp.
1798  if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
1799  nonForYieldedValues.push_back(yieldOperand.get());
1800  nonForResultIndices.push_back(yieldOperand.getOperandNumber());
1801  continue;
1802  }
1803  OpResult forResult = cast<OpResult>(yieldOperand.get());
1804  unsigned int forResultNumber = forResult.getResultNumber();
1805  forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
1806  // If this `ForOp` result has vector type and is yielded by the
1807  // `WarpOp`, we keep track of the distributed type for this result.
1808  if (!isa<VectorType>(forResult.getType()))
1809  continue;
1810  VectorType distType = cast<VectorType>(
1811  warpOp.getResult(yieldOperand.getOperandNumber()).getType());
1812  forResultDistTypes[forResultNumber] = distType;
1813  }
1814 
1815  // Newly created `WarpOp` will yield values in following order:
1816  // 1. All init args of the `ForOp`.
1817  // 2. All escaping values.
1818  // 3. All non-`ForOp` yielded values.
1819  SmallVector<Value> newWarpOpYieldValues;
1820  SmallVector<Type> newWarpOpDistTypes;
1821  for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
1822  newWarpOpYieldValues.push_back(initArg);
1823  // Compute the distributed type for this init arg.
1824  Type distType = initArg.getType();
1825  if (auto vecType = dyn_cast<VectorType>(distType)) {
1826  // If the `ForOp` result corresponding to this init arg is already yielded,
1827  // we can get the distributed type from the `forResultDistTypes` map.
1828  // Otherwise, we compute it using `distributionMapFn`.
1829  AffineMap map = distributionMapFn(initArg);
1830  distType = forResultDistTypes.count(i)
1831  ? forResultDistTypes[i]
1832  : getDistributedType(vecType, map, warpOp.getWarpSize());
1833  }
1834  newWarpOpDistTypes.push_back(distType);
1835  }
1836  // Insert escaping values and their distributed types.
1837  newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
1838  escapingValues.begin(), escapingValues.end());
1839  newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
1840  escapingValueDistTypes.begin(),
1841  escapingValueDistTypes.end());
1842  // Next, we insert all non-`ForOp` yielded values and their distributed
1843  // types. We also create a mapping between the non-`ForOp` yielded value
1844  // index and the corresponding new `WarpOp` yield value index (needed to
1845  // update users later).
1846  llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
1847  for (auto [i, v] :
1848  llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
1849  nonForResultMapping[i] = newWarpOpYieldValues.size();
1850  newWarpOpYieldValues.push_back(v);
1851  newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
1852  }
1853  // Create the new `WarpOp` with the updated yield values and types.
1854  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1855  rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1856 
1857  // Next, we create a new `ForOp` with the init args yielded by the new
1858  // `WarpOp`.
1859  const unsigned escapingValuesStartIdx =
1860  forOp.getInitArgs().size(); // `ForOp` init args are positioned before
1861  // escaping values in the new `WarpOp`.
1862  SmallVector<Value> newForOpOperands;
1863  for (size_t i = 0; i < escapingValuesStartIdx; ++i)
1864  newForOpOperands.push_back(newWarpOp.getResult(i));
1865 
1866  // Create a new `ForOp` outside the new `WarpOp` region.
1867  OpBuilder::InsertionGuard g(rewriter);
1868  rewriter.setInsertionPointAfter(newWarpOp);
1869  auto newForOp = scf::ForOp::create(
1870  rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1871  forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
1872  forOp.getUnsignedCmp());
1873  // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1874  // newly created `ForOp`. This `WarpOp` will contain all ops that were
1875  // contained within the original `ForOp` body.
1876  rewriter.setInsertionPointToStart(newForOp.getBody());
1877 
1878  SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
1879  newForOp.getRegionIterArgs().end());
1880  SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
1881  forOp.getResultTypes().end());
1882  // Escaping values are forwarded to the inner `WarpOp` as its (additional)
1883  // arguments. We keep track of the mapping between these values and their
1884  // argument index in the inner `WarpOp` (to replace users later).
1885  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1886  for (size_t i = escapingValuesStartIdx;
1887  i < escapingValuesStartIdx + escapingValues.size(); ++i) {
1888  innerWarpInput.push_back(newWarpOp.getResult(i));
1889  argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1890  innerWarpInputType.size();
1891  innerWarpInputType.push_back(
1892  escapingValueInputTypes[i - escapingValuesStartIdx]);
1893  }
1894  // Create the inner `WarpOp` with the new input values and types.
1895  auto innerWarp = WarpExecuteOnLane0Op::create(
1896  rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
1897  newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
1898  innerWarpInputType);
1899 
1900  // Inline the `ForOp` body into the inner `WarpOp` body.
1901  SmallVector<Value> argMapping;
1902  argMapping.push_back(newForOp.getInductionVar());
1903  for (Value args : innerWarp.getBody()->getArguments())
1904  argMapping.push_back(args);
1905 
1906  argMapping.resize(forOp.getBody()->getNumArguments());
1907  SmallVector<Value> yieldOperands;
1908  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1909  yieldOperands.push_back(operand);
1910 
1911  rewriter.eraseOp(forOp.getBody()->getTerminator());
1912  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1913 
1914  // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1915  // original `ForOp` results.
1916  rewriter.setInsertionPointToEnd(innerWarp.getBody());
1917  gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1918  rewriter.setInsertionPointAfter(innerWarp);
1919  // Insert a scf.yield op at the end of the new `ForOp` body that yields
1920  // the inner `WarpOp` results.
1921  if (!innerWarp.getResults().empty())
1922  scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
1923 
1924  // Update the users of original `WarpOp` results that were coming from the
1925  // original `ForOp` to the corresponding new `ForOp` result.
1926  for (auto [origIdx, newIdx] : forResultMapping)
1927  rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
1928  newForOp.getResult(newIdx), newForOp);
1929  // Similarly, update any users of the `WarpOp` results that were not
1930  // results of the `ForOp`.
1931  for (auto [origIdx, newIdx] : nonForResultMapping)
1932  rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
1933  newWarpOp.getResult(newIdx));
1934  // Remove the original `WarpOp` and `ForOp`, they should not have any uses
1935  // at this point.
1936  rewriter.eraseOp(forOp);
1937  rewriter.eraseOp(warpOp);
1938  // Update any users of escaping values that were forwarded to the
1939  // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1940  newForOp.walk([&](Operation *op) {
1941  for (OpOperand &operand : op->getOpOperands()) {
1942  auto it = argIndexMapping.find(operand.get());
1943  if (it == argIndexMapping.end())
1944  continue;
1945  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1946  }
1947  });
1948 
1949  // Finally, hoist out any now uniform code from the inner `WarpOp`.
1950  mlir::vector::moveScalarUniformCode(innerWarp);
1951  return success();
1952  }
1953 
1954 private:
1955  DistributionMapFn distributionMapFn;
1956 };
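Beyond the loop-carried values shown in the example above, the pattern also forwards values that are defined inside the original `WarpOp` but used inside the loop body (the escaping values collected by the pattern). A rough sketch of that case, assuming warp size 32, hypothetical producers, and 128 elements distributed to 4 per lane:

```
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4xf32>) {
  %def = "some_def"() : () -> (vector<128xf32>)
  %init = "some_init"() : () -> (vector<128xf32>)
  %v1 = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg = %init)
      -> (vector<128xf32>) {
    %acc = arith.addf %arg, %def : vector<128xf32>
    scf.yield %acc : vector<128xf32>
  }
  gpu.yield %v1 : vector<128xf32>
}
```

becomes roughly:

```
%w0:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4xf32>, vector<4xf32>) {
  %def = "some_def"() : () -> (vector<128xf32>)
  %init = "some_init"() : () -> (vector<128xf32>)
  gpu.yield %init, %def : vector<128xf32>, vector<128xf32>
}
%f = scf.for %i = %c0 to %c128 step %c1 iter_args(%arg = %w0#0)
    -> (vector<4xf32>) {
  %iw = gpu.warp_execute_on_lane_0(%laneid)[32]
      args(%arg, %w0#1 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>) {
  ^bb0(%barg: vector<128xf32>, %bdef: vector<128xf32>):
    %acc = arith.addf %barg, %bdef : vector<128xf32>
    gpu.yield %acc : vector<128xf32>
  }
  scf.yield %iw : vector<4xf32>
}
```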
1957 
1958 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1959 /// The vector is reduced in parallel. Currently limited to vector sizes that
1960 /// are a multiple of the warp size. E.g.:
1961 /// ```
1962 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1963 /// %0 = "some_def"() : () -> (vector<32xf32>)
1964 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1965 /// gpu.yield %1 : f32
1966 /// }
1967 /// ```
1968 /// is lowered to:
1969 /// ```
1970 /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1971 /// %1 = "some_def"() : () -> (vector<32xf32>)
1972 /// gpu.yield %1 : vector<32xf32>
1973 /// }
1974 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1975 /// %r = ("warp.reduction %a")
1976 /// ```
1977 struct WarpOpReduction : public WarpDistributionPattern {
1978  WarpOpReduction(MLIRContext *context,
1979  DistributedReductionFn distributedReductionFn,
1980  PatternBenefit benefit = 1)
1981  : WarpDistributionPattern(context, benefit),
1982  distributedReductionFn(std::move(distributedReductionFn)) {}
1983 
1984  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1985  PatternRewriter &rewriter) const override {
1986  OpOperand *yieldOperand =
1987  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1988  if (!yieldOperand)
1989  return failure();
1990 
1991  auto reductionOp =
1992  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1993  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1994  // Only rank 1 vectors supported.
1995  if (vectorType.getRank() != 1)
1996  return rewriter.notifyMatchFailure(
1997  warpOp, "Only rank 1 reductions can be distributed.");
1998  // Only vector sizes that are a multiple of the warp size are supported.
1999  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2000  return rewriter.notifyMatchFailure(
2001  warpOp, "Reduction vector dimension must be a multiple of the warp size.");
2002  if (!reductionOp.getType().isIntOrFloat())
2003  return rewriter.notifyMatchFailure(
2004  warpOp, "Reduction distribution currently only supports floats and "
2005  "integer types.");
2006 
2007  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2008  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
2009  unsigned operandIndex = yieldOperand->getOperandNumber();
2010  SmallVector<Value> yieldValues = {reductionOp.getVector()};
2011  SmallVector<Type> retTypes = {
2012  VectorType::get({numElements}, reductionOp.getType())};
2013  if (reductionOp.getAcc()) {
2014  yieldValues.push_back(reductionOp.getAcc());
2015  retTypes.push_back(reductionOp.getAcc().getType());
2016  }
2017  SmallVector<size_t> newRetIndices;
2018  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2019  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2020  rewriter.setInsertionPointAfter(newWarpOp);
2021 
2022  // Obtain data to reduce for a single lane.
2023  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2024  // Distribute and reduce across threads.
2025  Value fullReduce =
2026  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2027  reductionOp.getKind(), newWarpOp.getWarpSize());
2028  if (reductionOp.getAcc()) {
2029  fullReduce = vector::makeArithReduction(
2030  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2031  newWarpOp.getResult(newRetIndices[1]));
2032  }
2033  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
2034  return success();
2035  }
2036 
2037 private:
2038  DistributedReductionFn distributedReductionFn;
2039 };
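When the reduction carries an accumulator, the accumulator is also yielded by the new warp op and folded in after the cross-lane reduction. A rough sketch, assuming warp size 32 and hypothetical `"some_def"`/`"some_acc"` producers; the cross-lane reduction itself is produced by the user-supplied `distributedReductionFn`, so it is only indicated by a comment:

```
%r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
  %0 = "some_def"() : () -> (vector<32xf32>)
  %acc = "some_acc"() : () -> (f32)
  %1 = vector.reduction "add", %0, %acc : vector<32xf32> into f32
  gpu.yield %1 : f32
}
```

becomes roughly:

```
%w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>, f32) {
  %0 = "some_def"() : () -> (vector<32xf32>)
  %acc = "some_acc"() : () -> (f32)
  gpu.yield %0, %acc : vector<32xf32>, f32
}
// %red = cross-lane "add" reduction of the per-lane vector<1xf32> in %w#0,
//        emitted by distributedReductionFn (target specific).
%r = arith.addf %red, %w#1 : f32
```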
2040 
2041 } // namespace
2042 
2043 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
2044  RewritePatternSet &patterns,
2045  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
2046  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
2047 }
2048 
2049 void mlir::vector::populateDistributeTransferWriteOpPatterns(
2050  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2051  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2052  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2053  maxNumElementsToExtract, benefit);
2054 }
2055 
2056 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2057  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2058  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2059  PatternBenefit readBenefit) {
2060  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2061  patterns
2062  .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2063  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
2064  WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
2065  WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2066  patterns.getContext(), benefit);
2067  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2068  benefit);
2069  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2070  benefit);
2071 }
2072 
2073 void mlir::vector::populateDistributeReduction(
2074  RewritePatternSet &patterns,
2075  const DistributedReductionFn &distributedReductionFn,
2076  PatternBenefit benefit) {
2077  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2078  benefit);
2079 }
2080 
2081 /// Helper to know if an op can be hoisted out of the region.
2082 static bool canBeHoisted(Operation *op,
2083  function_ref<bool(Value)> definedOutside) {
2084  return llvm::all_of(op->getOperands(), definedOutside) &&
2085  isMemoryEffectFree(op) && op->getNumRegions() == 0;
2086 }
2087 
2088 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2089  Block *body = warpOp.getBody();
2090 
2091  // Keep track of the ops we want to hoist.
2092  llvm::SmallSetVector<Operation *, 8> opsToMove;
2093 
2094  // Helper to check if a value is or will be defined outside of the region.
2095  auto isDefinedOutsideOfBody = [&](Value value) {
2096  auto *definingOp = value.getDefiningOp();
2097  return (definingOp && opsToMove.count(definingOp)) ||
2098  warpOp.isDefinedOutsideOfRegion(value);
2099  };
2100 
2101  // Do not use walk here, as we do not want to go into nested regions and hoist
2102  // operations from there.
2103  for (auto &op : body->without_terminator()) {
2104  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2105  return isa<VectorType>(result.getType());
2106  });
2107  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2108  opsToMove.insert(&op);
2109  }
2110 
2111  // Move all the ops marked as uniform outside of the region.
2112  for (Operation *op : opsToMove)
2113  op->moveBefore(warpOp);
2114 }