//===- VectorDistribute.cpp - patterns to do vector distribution ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension with an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

/// Given a sequential and distributed vector type, returns the distributed
/// dimension. This function expects that only a single dimension is
/// distributed.
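/// For illustration (shapes chosen arbitrarily, not taken from the original
/// comment):
/// ```
/// sequential:  vector<32x16xf32>
/// distributed: vector<32x1xf32>
/// // dimension 1 differs (16 vs. 1) -> returns 1
/// ```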
static int getDistributedDim(VectorType sequentialType,
                             VectorType distributedType) {
  assert(sequentialType.getRank() == distributedType.getRank() &&
         "sequential and distributed vector types must have the same rank");
  int64_t distributedDim = -1;
  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
    if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
      // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
      // support distributing multiple dimensions in the future.
      assert(distributedDim == -1 && "found multiple distributed dims");
      distributedDim = i;
    }
  }
  return distributedDim;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
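  ///
  /// For example (illustrative, warp size 32): a sequential vector<32xf32>
  /// transits through a memref<32xf32>; when storing the distributed
  /// vector<1xf32>, each lane writes its slice at offset
  /// `laneid * distributedSize`, e.g.:
  /// ```
  /// vector.transfer_write %val, %buffer[%laneid]
  ///     : vector<1xf32>, memref<32xf32>
  /// ```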
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// If the `type` to load is the sequential type (e.g. when a value is
  /// broadcasted to all lanes rather than distributed), the load is not
  /// distributed, to account for the broadcast semantics of the
  /// `gpu.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///   gpu.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
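///
/// A rough sketch of the resulting structure (illustrative only; buffer
/// creation and synchronization are delegated to `options.warpAllocationFn`
/// and `options.warpSyncronizationFn`):
/// ```
/// // write distributed operands to temporary buffers
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   // read sequential values from the buffers
///   // ... original warp op body ...
///   // write yielded sequential values to result buffers
/// }
/// // read distributed results from the result buffers
/// ```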
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1. Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2. Insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a number of dimensions equal
/// to the original type rank and should be a projection where the results are
/// the distributed dimensions. The number of results should be equal to the
/// number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with the map `(d0, d1, d2) -> (d1)`
/// and a warp size of 16 distributes the second dimension (associated with d1)
/// and returns vector<16x2x64>.
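/// Another illustrative case (an assumption for exposition, not from the
/// original comment): when the warp size is larger than a distributed
/// dimension but divisible by it, that dimension collapses to 1 and the
/// remaining factor is applied to the next distributed dimension:
/// ```
/// vector<4x64xf32>, map (d0, d1) -> (d0, d1), warpSize = 32
///   d0: 4  -> 1   (warpSize becomes 32 / 4 = 8)
///   d1: 64 -> 8   (64 / 8 = 8, warpSize becomes 1)
/// => vector<1x8xf32>
/// ```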
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
/// will not be distributed (it should be less than the warp size).
///
/// Example:
/// ```
/// gpu.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, it cannot be distributed here; bail out so that
    // tryExtractOp can clone it into a separate WarpExecuteOnLane0Op.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5. Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element is captured strictly
      // by shape information on the warp op, and thus requires materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of small vectors (at most
  /// `maxNumElementsToExtract` elements) into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
  /// execute op with the proper return type. The new write op is updated to
  /// write the result of the new warp execute op. The old `writeOp` is
  /// deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op) &&
             op->getNumResults() == 1;
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element is captured strictly
      // by shape information on the warp op, and thus requires materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
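/// A hypothetical example (not from the original comment):
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, vector<1xf32>) {
///   ...
///   gpu.yield %a, %b : vector<32xf32>, vector<32xf32>
/// }
/// // Only %r#0 is used.
/// ```
/// becomes
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   gpu.yield %a : vector<32xf32>
/// }
/// ```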
// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If a value is directly yielded out of the region, we can forward it to the
// matching result without it needing to go through the region.
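// A hypothetical example (illustrative only):
// ```
// %r = gpu.warp_execute_on_lane_0(%laneid)[32]
//     args(%v : vector<4xf32>) -> (vector<4xf32>) {
// ^bb0(%arg1: vector<128xf32>):
//   gpu.yield %arg1 : vector<128xf32>
// }
// ```
// All uses of %r are replaced with %v directly.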
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

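/// Sink out a broadcast of a warp-uniform value feeding into a warp op yield.
/// A hypothetical example (illustrative only):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %b = vector.broadcast %s : f32 to vector<32xf32>
///   gpu.yield %b : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (f32) {
///   ...
///   gpu.yield %s : f32
/// }
/// %0 = vector.broadcast %r : f32 to vector<1xf32>
/// ```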
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // its own broadcast.
    // For that we simply check that the broadcast we want to build makes
    // sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is basically
/// a no-op for warp distribution; we only need to handle the shape.
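/// A hypothetical example (illustrative only):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %sc = vector.shape_cast %v : vector<1x32xf32> to vector<32xf32>
///   gpu.yield %sc : vector<32xf32>
/// }
/// ```
/// becomes
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x1xf32>) {
///   ...
///   gpu.yield %v : vector<1x32xf32>
/// }
/// %0 = vector.shape_cast %r : vector<1x1xf32> to vector<1xf32>
/// ```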
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   gpu.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], or in other words, the
      // mask sizes are always in the range [0, mask_vector_size[i]].
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out insert_strided_slice op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
///   ...
///   %src = ... : vector<4x32xf32>
///   %dest = ... : vector<8x32xf32>
///   %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
///     strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
///   gpu.yield %insert : vector<8x32xf32>
/// }
/// ```
/// To
/// ```
/// %0:2 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
///     vector<8x1xf32>) {
///   ...
///   %src = ... : vector<4x32xf32>
///   %dest = ... : vector<8x32xf32>
///   gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
/// }
/// %insert = vector.insert_strided_slice %0#0, %0#1,
///   offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
/// ```
/// NOTE: Current support assumes that both src and dest vectors are
/// distributed to lanes and sinking the insert op does not require any cross
/// lane communication.
struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp =
        operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          insertOp, "result vector type must be 2D or higher");
    // Find the distributed dimension of the dest vector. There should be
    // exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t destDistributedDim =
        getDistributedDim(yieldedType, distributedType);
    assert(destDistributedDim != -1 && "could not find distributed dimension");

    VectorType srcType = insertOp.getSourceVectorType();
    VectorType destType = insertOp.getDestVectorType();
    // Currently we require that both source (kD) and dest (nD) vectors are
    // distributed. This requires that distributedDim (d) is contained in the
    // last k dims of the dest vector (d >= n - k).
    // TODO: Add support for case where source vector is not distributed.
    int64_t sourceDistributedDim =
        destDistributedDim - (destType.getRank() - srcType.getRank());
    if (sourceDistributedDim < 0)
      return rewriter.notifyMatchFailure(
          insertOp,
          "distributed dimension must be in the last k dims of dest vector");
    // Distributed dimension must be fully inserted.
    if (srcType.getDimSize(sourceDistributedDim) !=
        destType.getDimSize(destDistributedDim))
      return rewriter.notifyMatchFailure(
          insertOp, "distributed dimension must be fully inserted");
    SmallVector<int64_t> newSourceDistShape(
        insertOp.getSourceVectorType().getShape());
    newSourceDistShape[sourceDistributedDim] =
        distributedType.getDimSize(destDistributedDim);
    auto newSourceTy =
        VectorType::get(newSourceDistShape, distributedType.getElementType());
    VectorType newDestTy = distributedType;
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
        {newSourceTy, newDestTy}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
    // Create a new insert strided slice op that inserts distributed source
    // into distributed dest.
    Value newInsert = rewriter.create<vector::InsertStridedSliceOp>(
        insertOp.getLoc(), distributedDest.getType(), distributedSource,
        distributedDest, insertOp.getOffsets(), insertOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
    return success();
  }
};

/// Sink out extract_strided_slice op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
///   ...
///   %src = ... : vector<64x32xf32>
///   %extract = vector.extract_strided_slice %src, offsets = [0],
///     sizes = [16], strides = [1] : vector<64x32xf32> to vector<16x32xf32>
///   gpu.yield %extract : vector<16x32xf32>
/// }
/// ```
/// To
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
///   ...
///   %src = ... : vector<64x32xf32>
///   gpu.yield %src : vector<64x32xf32>
/// }
/// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
///   strides = [1] : vector<64x1xf32> to vector<16x1xf32>
/// ```
/// NOTE: Current support assumes that the extraction happens only on
/// non-distributed dimensions (does not require cross lane communication).
struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp =
        operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    // Distributed type must be 2D or higher.
    // TODO: Support 1D distributed types.
    if (distributedType.getRank() < 2)
      return rewriter.notifyMatchFailure(
          extractOp, "result vector type must be 2D or higher");

    // Find the distributed dimension. There should be exactly one.
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");

    int64_t numOfExtractedDims =
        static_cast<int64_t>(extractOp.getSizes().size());
    // If the distributed dim is included in the extracted dims, then we make
    // sure the distributed dim is fully extracted. If the distributed dim is
    // not included in the extracted dims, it is guaranteed to be fully
    // extracted (i.e. the distributed dim comes after all the extracted dims).
    // TODO: Partial extraction from the distributed dimension requires cross
    // lane communication.
    if (distributedDim < numOfExtractedDims) {
      int64_t distributedDimOffset =
          llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
              .getInt();
      int64_t distributedDimSize =
          llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
              .getInt();
      if (distributedDimOffset != 0 ||
          distributedDimSize != yieldedType.getDimSize(distributedDim))
        return rewriter.notifyMatchFailure(
            extractOp, "distributed dimension must be fully extracted");
    }
    SmallVector<int64_t> newDistributedShape(
        extractOp.getSourceVectorType().getShape());
    newDistributedShape[distributedDim] =
        distributedType.getDimSize(distributedDim);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
        extractOp.getSizes(), [](Attribute attr) { return attr; });
    // Update the distributed sizes to match the distributed type.
    if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
      distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
          distributedType.getDimSize(distributedDim));

    // Create a new extract strided slice op that extracts from the
    // distributed vector.
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
        extractOp.getLoc(), distributedType, distributedVec,
        extractOp.getOffsets(),
        ArrayAttr::get(rewriter.getContext(), distributedSizes),
        extractOp.getStrides());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out a vector.extract from a 2-D or higher source vector.
/// The extract does not require cross-lane communication and can simply be
/// propagated to operate on the distributed source outside of the region;
/// 0-D and 1-D sources are handled by WarpOpExtractScalar.
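/// A hypothetical example (illustrative only, warp size 32):
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<3xf32>) {
///   ...
///   %e = vector.extract %v[1] : vector<96xf32> from vector<32x96xf32>
///   gpu.yield %e : vector<96xf32>
/// }
/// ```
/// becomes
/// ```
/// %r = gpu.warp_execute_on_lane_0(%arg0) -> (vector<32x3xf32>) {
///   ...
///   gpu.yield %v : vector<32x96xf32>
/// }
/// %0 = vector.extract %r[1] : vector<3xf32> from vector<32x3xf32>
/// ```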
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};
1362 
1363 /// Pattern to move out vector.extract with a scalar result.
1364 /// Only supports 1-D and 0-D sources for now.
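/// A sketch of the intended rewrite for the single-element case (illustrative
/// only):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %src = "some_def"() : () -> (vector<1xf32>)
///   %e = vector.extract %src[0] : f32 from vector<1xf32>
///   gpu.yield %e : f32
/// }
/// ```
/// becomes:
/// ```
/// %w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %src = "some_def"() : () -> (vector<1xf32>)
///   gpu.yield %src : vector<1xf32>
/// }
/// %r = vector.extract %w[0] : f32 from vector<1xf32>
/// ```
/// For larger 1-D sources the source vector is distributed instead: every lane
/// extracts `pos % elementsPerLane` from its slice, and the extracted value is
/// then shuffled from the owning lane to all lanes via the
/// `warpShuffleFromIdxFn` callback.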
1365 struct WarpOpExtractScalar : public WarpDistributionPattern {
1366  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1367  PatternBenefit b = 1)
1368  : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1369  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1370  PatternRewriter &rewriter) const override {
1371  OpOperand *operand =
1372  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1373  if (!operand)
1374  return failure();
1375  unsigned int operandNumber = operand->getOperandNumber();
1376  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1377  VectorType extractSrcType = extractOp.getSourceVectorType();
1378  // Only supports 1-D or 0-D sources for now.
1379  if (extractSrcType.getRank() > 1) {
1380  return rewriter.notifyMatchFailure(
1381  extractOp, "only 0-D or 1-D source supported for now");
1382  }
1383  // TODO: Supported shuffle types should be parameterizable, similar to
1384  // `WarpShuffleFromIdxFn`.
1385  if (!extractSrcType.getElementType().isF32() &&
1386  !extractSrcType.getElementType().isInteger(32))
1387  return rewriter.notifyMatchFailure(
1388  extractOp, "only f32/i32 element types are supported");
1389  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1390  Type elType = extractSrcType.getElementType();
1391  VectorType distributedVecType;
1392  if (!is0dOrVec1Extract) {
1393  assert(extractSrcType.getRank() == 1 &&
1394  "expected that extract src rank is 0 or 1");
1395  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1396  return failure();
1397  int64_t elementsPerLane =
1398  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1399  distributedVecType = VectorType::get({elementsPerLane}, elType);
1400  } else {
1401  distributedVecType = extractSrcType;
1402  }
1403  // Yield source vector and position (if present) from warp op.
1404  SmallVector<Value> additionalResults{extractOp.getVector()};
1405  SmallVector<Type> additionalResultTypes{distributedVecType};
1406  additionalResults.append(
1407  SmallVector<Value>(extractOp.getDynamicPosition()));
1408  additionalResultTypes.append(
1409  SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1410 
1411  Location loc = extractOp.getLoc();
1412  SmallVector<size_t> newRetIndices;
1413  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1414  rewriter, warpOp, additionalResults, additionalResultTypes,
1415  newRetIndices);
1416  rewriter.setInsertionPointAfter(newWarpOp);
1417  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1418 
 1419  // 0-D or single-element extract: The new warp op broadcasts the source
 1420  // vector to all lanes. All lanes extract the scalar.
1421  if (is0dOrVec1Extract) {
1422  Value newExtract;
1423  SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1424  newExtract =
1425  rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
1426  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1427  newExtract);
1428  return success();
1429  }
1430 
1431  int64_t staticPos = extractOp.getStaticPosition()[0];
1432  OpFoldResult pos = ShapedType::isDynamic(staticPos)
1433  ? (newWarpOp->getResult(newRetIndices[1]))
1434  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1435  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1436  // the value to all other lanes.
1437  int64_t elementsPerLane = distributedVecType.getShape()[0];
1438  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1439  // tid of extracting thread: pos / elementsPerLane
1440  Value broadcastFromTid = affine::makeComposedAffineApply(
1441  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1442  // Extract at position: pos % elementsPerLane
1443  Value newPos =
1444  elementsPerLane == 1
1445  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1446  : affine::makeComposedAffineApply(rewriter, loc,
1447  sym0 % elementsPerLane, pos);
1448  Value extracted =
1449  rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
1450 
1451  // Shuffle the extracted value to all lanes.
1452  Value shuffled = warpShuffleFromIdxFn(
1453  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1454  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1455  return success();
1456  }
1457 
1458 private:
1459  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1460 };
1461 
1462 /// Pattern to move out vector.insert with a scalar input.
1463 /// Only supports 1-D and 0-D destinations for now.
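/// A sketch of the intended rewrite for a distributed 1-D destination
/// (illustrative only; assumes a warp of 32 lanes, so vector<64xf32> is
/// distributed as vector<2xf32>, and `%is_inserting_lane` stands for the
/// comparison of %laneid with the lane that owns element 9):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
///   %f = "some_def"() : () -> (f32)
///   %dst = "some_def"() : () -> (vector<64xf32>)
///   %i = vector.insert %f, %dst [9] : f32 into vector<64xf32>
///   gpu.yield %i : vector<64xf32>
/// }
/// ```
/// becomes, roughly:
/// ```
/// %w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>, f32) {
///   %f = "some_def"() : () -> (f32)
///   %dst = "some_def"() : () -> (vector<64xf32>)
///   gpu.yield %dst, %f : vector<64xf32>, f32
/// }
/// // Only the lane owning element 9 inserts, at position 9 % 2 within its
/// // slice; all other lanes forward their slice unchanged.
/// %r = scf.if %is_inserting_lane -> (vector<2xf32>) {
///   %ins = vector.insert %w#1, %w#0 [1] : f32 into vector<2xf32>
///   scf.yield %ins : vector<2xf32>
/// } else {
///   scf.yield %w#0 : vector<2xf32>
/// }
/// ```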
1464 struct WarpOpInsertScalar : public WarpDistributionPattern {
1465  using Base::Base;
1466  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1467  PatternRewriter &rewriter) const override {
1468  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1469  if (!operand)
1470  return failure();
1471  unsigned int operandNumber = operand->getOperandNumber();
1472  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1473  VectorType vecType = insertOp.getDestVectorType();
1474  VectorType distrType =
1475  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1476 
1477  // Only supports 1-D or 0-D destinations for now.
1478  if (vecType.getRank() > 1) {
1479  return rewriter.notifyMatchFailure(
 1480  insertOp, "only 0-D or 1-D destination supported for now");
1481  }
1482 
1483  // Yield destination vector, source scalar and position from warp op.
1484  SmallVector<Value> additionalResults{insertOp.getDest(),
1485  insertOp.getValueToStore()};
1486  SmallVector<Type> additionalResultTypes{
1487  distrType, insertOp.getValueToStore().getType()};
1488  additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1489  additionalResultTypes.append(
1490  SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1491 
1492  Location loc = insertOp.getLoc();
1493  SmallVector<size_t> newRetIndices;
1494  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1495  rewriter, warpOp, additionalResults, additionalResultTypes,
1496  newRetIndices);
1497  rewriter.setInsertionPointAfter(newWarpOp);
1498  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1499  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1500  rewriter.setInsertionPointAfter(newWarpOp);
1501 
1502  OpFoldResult pos;
1503  if (vecType.getRank() != 0) {
1504  int64_t staticPos = insertOp.getStaticPosition()[0];
1505  pos = ShapedType::isDynamic(staticPos)
1506  ? (newWarpOp->getResult(newRetIndices[2]))
1507  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1508  }
1509 
1510  // This condition is always true for 0-d vectors.
1511  if (vecType == distrType) {
1512  Value newInsert;
1513  SmallVector<OpFoldResult> indices;
1514  if (pos) {
1515  indices.push_back(pos);
1516  }
1517  newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1518  distributedVec, indices);
1519  // Broadcast: Simply move the vector.insert op out.
1520  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1521  newInsert);
1522  return success();
1523  }
1524 
1525  // This is a distribution. Only one lane should insert.
1526  int64_t elementsPerLane = distrType.getShape()[0];
1527  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
 1528  // tid of inserting lane: pos / elementsPerLane
1529  Value insertingLane = affine::makeComposedAffineApply(
1530  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1531  // Insert position: pos % elementsPerLane
 1532  OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
 1533  rewriter, loc, sym0 % elementsPerLane, pos);
1534  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1535  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1536  Value newResult =
1537  rewriter
1538  .create<scf::IfOp>(
1539  loc, isInsertingLane,
1540  /*thenBuilder=*/
1541  [&](OpBuilder &builder, Location loc) {
1542  Value newInsert = builder.create<vector::InsertOp>(
1543  loc, newSource, distributedVec, newPos);
1544  builder.create<scf::YieldOp>(loc, newInsert);
1545  },
1546  /*elseBuilder=*/
1547  [&](OpBuilder &builder, Location loc) {
1548  builder.create<scf::YieldOp>(loc, distributedVec);
1549  })
1550  .getResult(0);
1551  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1552  return success();
1553  }
1554 };
1555 
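/// Pattern to move out vector.insert of a vector source into a 2-D or higher
/// dimensional destination. Either every lane inserts its own piece of the
/// distributed source, or (when the distributed dimension is covered by the
/// insertion indices) a single lane inserts the full source under an scf.if.
/// A sketch of the first case (illustrative only; warp of 32 lanes, trailing
/// dimension distributed):
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<128x3xf32>) {
///   %s = "some_def"() : () -> (vector<96xf32>)
///   %d = "some_def"() : () -> (vector<128x96xf32>)
///   %i = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
///   gpu.yield %i : vector<128x96xf32>
/// }
/// ```
/// becomes:
/// ```
/// %w:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
///     -> (vector<3xf32>, vector<128x3xf32>) {
///   %s = "some_def"() : () -> (vector<96xf32>)
///   %d = "some_def"() : () -> (vector<128x96xf32>)
///   gpu.yield %s, %d : vector<96xf32>, vector<128x96xf32>
/// }
/// %r = vector.insert %w#0, %w#1 [2] : vector<3xf32> into vector<128x3xf32>
/// ```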
1556 struct WarpOpInsert : public WarpDistributionPattern {
1557  using Base::Base;
1558  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1559  PatternRewriter &rewriter) const override {
1560  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1561  if (!operand)
1562  return failure();
1563  unsigned int operandNumber = operand->getOperandNumber();
1564  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1565  Location loc = insertOp.getLoc();
1566 
1567  // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1568  if (insertOp.getDestVectorType().getRank() <= 1) {
1569  return failure();
1570  }
1571 
 1572  // All following cases handle 2-D or higher-dimensional destination vectors.
1573 
1574  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1575  // There is no distribution, this is a broadcast. Simply move the insert
1576  // out of the warp op.
1577  SmallVector<size_t> newRetIndices;
1578  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1579  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1580  {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1581  newRetIndices);
1582  rewriter.setInsertionPointAfter(newWarpOp);
1583  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1584  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1585  Value newResult = rewriter.create<vector::InsertOp>(
1586  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1587  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1588  newResult);
1589  return success();
1590  }
1591 
1592  // Find the distributed dimension. There should be exactly one.
1593  auto distrDestType =
1594  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1595  auto yieldedType = cast<VectorType>(operand->get().getType());
1596  int64_t distrDestDim = -1;
1597  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1598  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1599  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1600  // support distributing multiple dimensions in the future.
1601  assert(distrDestDim == -1 && "found multiple distributed dims");
1602  distrDestDim = i;
1603  }
1604  }
1605  assert(distrDestDim != -1 && "could not find distributed dimension");
1606 
1607  // Compute the distributed source vector type.
1608  VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1609  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1610  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1611  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1612  // insert a smaller vector<3xf32>.
1613  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1614  // case, one lane will insert the source vector<96xf32>. The other
1615  // lanes will not do anything.
1616  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1617  if (distrSrcDim >= 0)
1618  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1619  auto distrSrcType =
1620  VectorType::get(distrSrcShape, distrDestType.getElementType());
1621 
1622  // Yield source and dest vectors from warp op.
1623  SmallVector<size_t> newRetIndices;
1624  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1625  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1626  {distrSrcType, distrDestType}, newRetIndices);
1627  rewriter.setInsertionPointAfter(newWarpOp);
1628  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1629  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1630 
1631  // Insert into the distributed vector.
1632  Value newResult;
1633  if (distrSrcDim >= 0) {
1634  // Every lane inserts a small piece.
1635  newResult = rewriter.create<vector::InsertOp>(
1636  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1637  } else {
1638  // One lane inserts the entire source vector.
1639  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1640  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1641  SmallVector<int64_t> newPos = getAsIntegers(pos);
1642  // tid of inserting lane: pos / elementsPerLane
1643  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1644  loc, newPos[distrDestDim] / elementsPerLane);
1645  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1646  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1647  // Insert position: pos % elementsPerLane
1648  newPos[distrDestDim] %= elementsPerLane;
1649  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1650  Value newInsert = builder.create<vector::InsertOp>(
1651  loc, distributedSrc, distributedDest, newPos);
1652  builder.create<scf::YieldOp>(loc, newInsert);
1653  };
1654  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1655  builder.create<scf::YieldOp>(loc, distributedDest);
1656  };
1657  newResult = rewriter
1658  .create<scf::IfOp>(loc, isInsertingLane,
1659  /*thenBuilder=*/insertingBuilder,
1660  /*elseBuilder=*/nonInsertingBuilder)
1661  .getResult(0);
1662  }
1663 
1664  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1665  return success();
1666  }
1667 };
1668 
1669 /// Sink the scf.for op out of WarpExecuteOnLane0Op. This can be done only if
1670 /// the scf.ForOp is the last operation in the region so that it doesn't
1671 /// change the order of execution. This creates a new scf.for region after the
1672 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1673 /// WarpExecuteOnLane0Op region. Example:
1674 /// ```
1675 /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1676 /// ...
1677 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1678 /// -> (vector<128xf32>) {
1679 /// ...
1680 /// scf.yield %r : vector<128xf32>
1681 /// }
1682 /// gpu.yield %v1 : vector<128xf32>
1683 /// }
1684 /// ```
1685 /// To:
/// ```
1686 /// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1687 /// ...
1688 /// gpu.yield %v : vector<128xf32>
1689 /// }
1690 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1691 /// -> (vector<4xf32>) {
1692 /// %iw = gpu.warp_execute_on_lane_0(%laneid)
1693 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1694 /// ^bb0(%arg: vector<128xf32>):
1695 /// ...
1696 /// gpu.yield %ir : vector<128xf32>
1697 /// }
1698 /// scf.yield %iw : vector<4xf32>
1699 /// }
1700 /// ```
1701 struct WarpOpScfForOp : public WarpDistributionPattern {
1702 
1703  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1704  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1705  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1706  PatternRewriter &rewriter) const override {
1707  auto warpOpYield = cast<gpu::YieldOp>(
1708  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1709  // Only pick up `ForOp` if it is the last op in the region.
1710  Operation *lastNode = warpOpYield->getPrevNode();
1711  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1712  if (!forOp)
1713  return failure();
1714  // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
1715  // Those Values need to be returned by the new warp op.
1716  llvm::SmallSetVector<Value, 32> escapingValues;
1717  SmallVector<Type> escapingValueInputTypes;
1718  SmallVector<Type> escapingValueDistTypes;
 1719  mlir::visitUsedValuesDefinedAbove(
 1720  forOp.getBodyRegion(), [&](OpOperand *operand) {
1721  Operation *parent = operand->get().getParentRegion()->getParentOp();
1722  if (warpOp->isAncestor(parent)) {
1723  if (!escapingValues.insert(operand->get()))
1724  return;
1725  Type distType = operand->get().getType();
1726  if (auto vecType = dyn_cast<VectorType>(distType)) {
1727  AffineMap map = distributionMapFn(operand->get());
1728  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1729  }
1730  escapingValueInputTypes.push_back(operand->get().getType());
1731  escapingValueDistTypes.push_back(distType);
1732  }
1733  });
1734 
1735  if (llvm::is_contained(escapingValueDistTypes, Type{}))
1736  return failure();
1737  // `WarpOp` can yield two types of values:
1738  // 1. Values that are not results of the `ForOp`:
1739  // These values must also be yielded by the new `WarpOp`. Also, we need
1740  // to record the index mapping for these values to replace them later.
1741  // 2. Values that are results of the `ForOp`:
1742  // In this case, we record the index mapping between the `WarpOp` result
1743  // index and matching `ForOp` result index.
1744  // Additionally, we keep track of the distributed types for all `ForOp`
1745  // vector results.
1746  SmallVector<Value> nonForYieldedValues;
1747  SmallVector<unsigned> nonForResultIndices;
1748  llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
1749  llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
1750  for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1751  // Yielded value is not a result of the forOp.
1752  if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
1753  nonForYieldedValues.push_back(yieldOperand.get());
1754  nonForResultIndices.push_back(yieldOperand.getOperandNumber());
1755  continue;
1756  }
1757  OpResult forResult = cast<OpResult>(yieldOperand.get());
1758  unsigned int forResultNumber = forResult.getResultNumber();
1759  forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
 1760  // If this `ForOp` result has a vector type and is yielded by the
 1761  // `WarpOp`, we keep track of the distributed type for this result.
1762  if (!isa<VectorType>(forResult.getType()))
1763  continue;
1764  VectorType distType = cast<VectorType>(
1765  warpOp.getResult(yieldOperand.getOperandNumber()).getType());
1766  forResultDistTypes[forResultNumber] = distType;
1767  }
1768 
1769  // Newly created `WarpOp` will yield values in following order:
1770  // 1. All init args of the `ForOp`.
1771  // 2. All escaping values.
1772  // 3. All non-`ForOp` yielded values.
1773  SmallVector<Value> newWarpOpYieldValues;
1774  SmallVector<Type> newWarpOpDistTypes;
1775  for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
1776  newWarpOpYieldValues.push_back(initArg);
1777  // Compute the distributed type for this init arg.
1778  Type distType = initArg.getType();
1779  if (auto vecType = dyn_cast<VectorType>(distType)) {
 1780  // If the `ForOp` result corresponding to this init arg is already yielded,
 1781  // we can get the distributed type from the `forResultDistTypes` map.
 1782  // Otherwise, we compute it using distributionMapFn.
1783  AffineMap map = distributionMapFn(initArg);
1784  distType = forResultDistTypes.count(i)
1785  ? forResultDistTypes[i]
1786  : getDistributedType(vecType, map, warpOp.getWarpSize());
1787  }
1788  newWarpOpDistTypes.push_back(distType);
1789  }
1790  // Insert escaping values and their distributed types.
1791  newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
1792  escapingValues.begin(), escapingValues.end());
1793  newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
1794  escapingValueDistTypes.begin(),
1795  escapingValueDistTypes.end());
1796  // Next, we insert all non-`ForOp` yielded values and their distributed
1797  // types. We also create a mapping between the non-`ForOp` yielded value
1798  // index and the corresponding new `WarpOp` yield value index (needed to
1799  // update users later).
1800  llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
1801  for (auto [i, v] :
1802  llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
1803  nonForResultMapping[i] = newWarpOpYieldValues.size();
1804  newWarpOpYieldValues.push_back(v);
1805  newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
1806  }
1807  // Create the new `WarpOp` with the updated yield values and types.
1808  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1809  rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1810 
1811  // Next, we create a new `ForOp` with the init args yielded by the new
1812  // `WarpOp`.
1813  const unsigned escapingValuesStartIdx =
1814  forOp.getInitArgs().size(); // `ForOp` init args are positioned before
1815  // escaping values in the new `WarpOp`.
1816  SmallVector<Value> newForOpOperands;
1817  for (size_t i = 0; i < escapingValuesStartIdx; ++i)
1818  newForOpOperands.push_back(newWarpOp.getResult(i));
1819 
1820  // Create a new `ForOp` outside the new `WarpOp` region.
1821  OpBuilder::InsertionGuard g(rewriter);
1822  rewriter.setInsertionPointAfter(newWarpOp);
1823  auto newForOp = rewriter.create<scf::ForOp>(
1824  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1825  forOp.getStep(), newForOpOperands);
1826  // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
1827  // newly created `ForOp`. This `WarpOp` will contain all ops that were
1828  // contained within the original `ForOp` body.
1829  rewriter.setInsertionPointToStart(newForOp.getBody());
1830 
1831  SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
1832  newForOp.getRegionIterArgs().end());
1833  SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
1834  forOp.getResultTypes().end());
1835  // Escaping values are forwarded to the inner `WarpOp` as its (additional)
1836  // arguments. We keep track of the mapping between these values and their
1837  // argument index in the inner `WarpOp` (to replace users later).
1838  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1839  for (size_t i = escapingValuesStartIdx;
1840  i < escapingValuesStartIdx + escapingValues.size(); ++i) {
1841  innerWarpInput.push_back(newWarpOp.getResult(i));
1842  argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
1843  innerWarpInputType.size();
1844  innerWarpInputType.push_back(
1845  escapingValueInputTypes[i - escapingValuesStartIdx]);
1846  }
1847  // Create the inner `WarpOp` with the new input values and types.
1848  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1849  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1850  newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
1851 
1852  // Inline the `ForOp` body into the inner `WarpOp` body.
1853  SmallVector<Value> argMapping;
1854  argMapping.push_back(newForOp.getInductionVar());
1855  for (Value args : innerWarp.getBody()->getArguments())
1856  argMapping.push_back(args);
1857 
1858  argMapping.resize(forOp.getBody()->getNumArguments());
1859  SmallVector<Value> yieldOperands;
1860  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1861  yieldOperands.push_back(operand);
1862 
1863  rewriter.eraseOp(forOp.getBody()->getTerminator());
1864  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1865 
1866  // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
1867  // original `ForOp` results.
1868  rewriter.setInsertionPointToEnd(innerWarp.getBody());
1869  rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
1870  rewriter.setInsertionPointAfter(innerWarp);
1871  // Insert a scf.yield op at the end of the new `ForOp` body that yields
1872  // the inner `WarpOp` results.
1873  if (!innerWarp.getResults().empty())
1874  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1875 
1876  // Update the users of original `WarpOp` results that were coming from the
1877  // original `ForOp` to the corresponding new `ForOp` result.
1878  for (auto [origIdx, newIdx] : forResultMapping)
1879  rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
1880  newForOp.getResult(newIdx), newForOp);
1881  // Similarly, update any users of the `WarpOp` results that were not
1882  // results of the `ForOp`.
1883  for (auto [origIdx, newIdx] : nonForResultMapping)
1884  rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
1885  newWarpOp.getResult(newIdx));
1886  // Remove the original `WarpOp` and `ForOp`, they should not have any uses
1887  // at this point.
1888  rewriter.eraseOp(forOp);
1889  rewriter.eraseOp(warpOp);
1890  // Update any users of escaping values that were forwarded to the
1891  // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
1892  newForOp.walk([&](Operation *op) {
1893  for (OpOperand &operand : op->getOpOperands()) {
1894  auto it = argIndexMapping.find(operand.get());
1895  if (it == argIndexMapping.end())
1896  continue;
1897  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1898  }
1899  });
1900 
1901  // Finally, hoist out any now uniform code from the inner `WarpOp`.
1902  mlir::vector::moveScalarUniformCode(innerWarp);
1903  return success();
1904  }
1905 
1906 private:
1907  DistributionMapFn distributionMapFn;
1908 };
1909 
1910 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1911 /// The vector is reduced in parallel. Currently limited to vector sizes that
1912 /// are a multiple of the warp size. E.g.:
1913 /// ```
1914 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1915 /// %0 = "some_def"() : () -> (vector<32xf32>)
1916 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1917 /// gpu.yield %1 : f32
1918 /// }
1919 /// ```
1920 /// is lowered to:
1921 /// ```
1922 /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1923 /// %1 = "some_def"() : () -> (vector<32xf32>)
1924 /// gpu.yield %1 : vector<32xf32>
1925 /// }
1926 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1927 /// %r = ("warp.reduction %a")
1928 /// ```
1929 struct WarpOpReduction : public WarpDistributionPattern {
1930  WarpOpReduction(MLIRContext *context,
1931  DistributedReductionFn distributedReductionFn,
1932  PatternBenefit benefit = 1)
1933  : WarpDistributionPattern(context, benefit),
1934  distributedReductionFn(std::move(distributedReductionFn)) {}
1935 
1936  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1937  PatternRewriter &rewriter) const override {
1938  OpOperand *yieldOperand =
1939  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1940  if (!yieldOperand)
1941  return failure();
1942 
1943  auto reductionOp =
1944  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1945  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1946  // Only rank 1 vectors supported.
1947  if (vectorType.getRank() != 1)
1948  return rewriter.notifyMatchFailure(
1949  warpOp, "Only rank 1 reductions can be distributed.");
 1950  // Only vector sizes that are a multiple of the warp size are supported.
 1951  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
 1952  return rewriter.notifyMatchFailure(
 1953  warpOp, "Reduction vector dimension must be a multiple of the warp size.");
1954  if (!reductionOp.getType().isIntOrFloat())
1955  return rewriter.notifyMatchFailure(
1956  warpOp, "Reduction distribution currently only supports floats and "
1957  "integer types.");
1958 
1959  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1960  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1961  unsigned operandIndex = yieldOperand->getOperandNumber();
1962  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1963  SmallVector<Type> retTypes = {
1964  VectorType::get({numElements}, reductionOp.getType())};
1965  if (reductionOp.getAcc()) {
1966  yieldValues.push_back(reductionOp.getAcc());
1967  retTypes.push_back(reductionOp.getAcc().getType());
1968  }
1969  SmallVector<size_t> newRetIndices;
1970  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1971  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1972  rewriter.setInsertionPointAfter(newWarpOp);
1973 
1974  // Obtain data to reduce for a single lane.
1975  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1976  // Distribute and reduce across threads.
1977  Value fullReduce =
1978  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1979  reductionOp.getKind(), newWarpOp.getWarpSize());
1980  if (reductionOp.getAcc()) {
1981  fullReduce = vector::makeArithReduction(
1982  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1983  newWarpOp.getResult(newRetIndices[1]));
1984  }
1985  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1986  return success();
1987  }
1988 
1989 private:
1990  DistributedReductionFn distributedReductionFn;
1991 };
1992 
1993 } // namespace
1994 
 1995 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
 1996  RewritePatternSet &patterns,
 1997  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
 1998  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1999 }
2000 
2001 void mlir::vector::populateDistributeTransferWriteOpPatterns(
2002  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2003  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2004  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2005  maxNumElementsToExtract, benefit);
2006 }
2007 
2008 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2009  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2010  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2011  PatternBenefit readBenefit) {
2012  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2013  patterns
2014  .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2015  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
2016  WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
2017  WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>(
2018  patterns.getContext(), benefit);
2019  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2020  benefit);
2021  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2022  benefit);
2023 }
2024 
2025 void mlir::vector::populateDistributeReduction(
 2026  RewritePatternSet &patterns,
 2027  const DistributedReductionFn &distributedReductionFn,
2028  PatternBenefit benefit) {
2029  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2030  benefit);
2031 }
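// The populate* entry points above are typically combined by a client pass.
// A minimal sketch of such a driver (assumes user-provided distributionMapFn,
// warpShuffleFromIdxFn and distributedReductionFn callbacks and a greedy
// rewrite driver; this is not the canonical pipeline):
//
//   RewritePatternSet patterns(op->getContext());
//   vector::populateDistributeTransferWriteOpPatterns(
//       patterns, distributionMapFn, /*maxNumElementsToExtract=*/1);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionMapFn, warpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, distributedReductionFn);
//   (void)applyPatternsGreedily(op, std::move(patterns));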
2032 
2033 /// Helper to know if an op can be hoisted out of the region.
2034 static bool canBeHoisted(Operation *op,
2035  function_ref<bool(Value)> definedOutside) {
2036  return llvm::all_of(op->getOperands(), definedOutside) &&
2037  isMemoryEffectFree(op) && op->getNumRegions() == 0;
2038 }
2039 
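/// Hoist ops out of the warp-op body when they have no vector results, are
/// memory-effect free, have no regions, and all their operands are defined
/// outside the region (or by other hoisted ops). Such ops compute the same
/// value on every lane, so executing them once before the warp op is
/// equivalent. A sketch (illustrative only):
/// ```
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
///   %c4 = arith.constant 4 : index
///   %i = arith.addi %n, %c4 : index   // %n is defined above the warp op.
///   ... vector code using %i ...
/// }
/// ```
/// becomes:
/// ```
/// %c4 = arith.constant 4 : index
/// %i = arith.addi %n, %c4 : index
/// gpu.warp_execute_on_lane_0(%laneid)[32] {
///   ... vector code using %i ...
/// }
/// ```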
2040 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2041  Block *body = warpOp.getBody();
2042 
2043  // Keep track of the ops we want to hoist.
2044  llvm::SmallSetVector<Operation *, 8> opsToMove;
2045 
2046  // Helper to check if a value is or will be defined outside of the region.
2047  auto isDefinedOutsideOfBody = [&](Value value) {
2048  auto *definingOp = value.getDefiningOp();
2049  return (definingOp && opsToMove.count(definingOp)) ||
2050  warpOp.isDefinedOutsideOfRegion(value);
2051  };
2052 
2053  // Do not use walk here, as we do not want to go into nested regions and hoist
2054  // operations from there.
2055  for (auto &op : body->without_terminator()) {
2056  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2057  return isa<VectorType>(result.getType());
2058  });
2059  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2060  opsToMove.insert(&op);
2061  }
2062 
2063  // Move all the ops marked as uniform outside of the region.
2064  for (Operation *op : opsToMove)
2065  op->moveBefore(warpOp);
2066 }