//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::gpu;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   gpu.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing the lowering implemented by `WarpOpToScfIfPattern`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

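  /// Build the offset of lane `laneId` along distributed dimension `index`,
  /// i.e. `laneId * distributedVectorType.getDimSize(index)`. For example,
  /// with a distributed dimension size of 2, lane 5 addresses the sequential
  /// buffer at offset 10 along that dimension.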
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
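  /// E.g. an f32 value transits through a memref<1xf32> buffer and a
  /// vector<32xf32> value through a memref<32xf32> buffer.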
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `gpu.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When the result type and the yielded type match (e.g. for scalars), the
  /// load is not distributed, to account for the broadcast semantics of the
  /// `gpu.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
  ///   gpu.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vectors atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
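///
/// A rough sketch of the result (buffer shapes and synchronization depend on
/// the user-provided `warpAllocationFn` / `warpSyncronizationFn` and are only
/// illustrative here):
/// ```
/// // Before: %r = gpu.warp_execute_on_lane_0(%laneid)[32]
/// //             args(%farg : vector<1xf32>) -> (vector<1xf32>) { ... }
/// // After (sketch):
/// vector.transfer_write %farg, %arg_buffer[%laneid] ...
/// gpu.barrier
/// %c0 = arith.constant 0 : index
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane_0 {
///   %seq_arg = vector.transfer_read %arg_buffer[%c0] ... : vector<32xf32>
///   // ... original body, yielded value written into %res_buffer ...
/// }
/// gpu.barrier
/// %r = vector.transfer_read %res_buffer[%laneid] ... : vector<1xf32>
/// ```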
struct WarpOpToScfIfPattern : public WarpDistributionPattern {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : WarpDistributionPattern(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // If the result type and the yielded value type are the same, this is a
      // broadcast. E.g.:
      // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
      //   gpu.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have as many dimensions as the
/// rank of the original type and to be a projection where the results are the
/// distributed dimensions. The number of results should be equal to the number
/// of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1) and
/// a warp size of 16 distributes the second dimension (associated to d1) and
/// returns vector<16x2x64>.
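/// If the distributed dimensions cannot absorb the entire warp, an empty
/// VectorType is returned to signal failure. E.g. a vector<8> distributed
/// with the map (d0) -> (d0) and a warp size of 16 reduces d0 to 1 but leaves
/// a residual factor of 2 lanes, so no distributed type is produced.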
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
/// will not be extracted (the limit should be less than the warp size).
///
/// Example:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%id) {
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   gpu.yield
/// }
/// ```
/// To
/// ```
/// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   gpu.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public WarpDistributionPattern {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, we cannot distribute it; bail out.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of small vectors (at most
  /// `maxNumElementsToExtract` elements) into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<gpu::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
  /// execute op with the proper return type. The new write op is updated to
  /// write the result of the new warp execute op. The old `writeOp` is
  /// deleted.
  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                       WarpExecuteOnLane0Op warpOp,
                                       vector::TransferWriteOp writeOp,
                                       VectorType targetType,
                                       VectorType maybeMaskType) const {
    assert(writeOp->getParentOp() == warpOp &&
           "write must be nested immediately under warp");
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp;
    if (maybeMaskType) {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
          TypeRange{targetType, maybeMaskType}, newRetIndices);
    } else {
      newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, ValueRange{{writeOp.getVector()}},
          TypeRange{targetType}, newRetIndices);
    }
    rewriter.setInsertionPointAfter(newWarpOp);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);
    newWriteOp.getValueToStoreMutable().assign(
        newWarpOp.getResult(newRetIndices[0]));
    if (maybeMaskType)
      newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
    return newWriteOp;
  }

  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   gpu.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op) &&
             op->getNumResults() == 1;
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   gpu.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   gpu.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted read vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
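///
/// For example (sketch), with two results where one is dead and both yield the
/// same value:
/// ```
/// %r:2 = gpu.warp_execute_on_lane_0(%laneid)[32]
///            -> (vector<1xf32>, vector<1xf32>) {
///   ...
///   gpu.yield %v, %v : vector<32xf32>, vector<32xf32>
/// }
/// // Only %r#1 is used.
/// ```
/// becomes a warp op with a single result replacing %r#1; %r#0 is dropped.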
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

/// If an operand is directly yielded out of the region, we can forward it
/// directly: it does not need to be distributed through the region.
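///
/// For example (sketch), forwarding a value that comes from above the warp op:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32]
///          args(%uniform : f32) -> (f32) {
/// ^bb0(%arg0 : f32):
///   gpu.yield %arg0 : f32
/// }
/// ```
/// Uses of %0 are replaced by %uniform and the result becomes dead.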
struct WarpOpForwardOperand : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<gpu::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

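/// Sink out a vector.broadcast that feeds a warp op yield when the source is
/// uniform across lanes, i.e. each lane can rebuild the broadcast on its own.
/// The source is yielded as-is and the broadcast to the distributed result
/// type is re-created after the warp op.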
struct WarpOpBroadcast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // its own broadcast.
    // For that we simply check that the broadcast we want to build makes
    // sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a vector.shape_cast out of the warp op. A shape cast is
/// basically a no-op for warp distribution; we only need to adjust the shape.
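///
/// For example (sketch), with a warp size of 32:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = vector.shape_cast %v : vector<1x32xf32> to vector<32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes a warp op yielding %v as vector<1x1xf32>, followed by
/// `vector.shape_cast %r : vector<1x1xf32> to vector<1xf32>` after the op.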
struct WarpOpShapeCast : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   gpu.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// gpu.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], or in other words,
      // the mask sizes are always in the range [0, mask_vector_size[i]].
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Pattern to move out vector.extract of 2-D and higher-rank source vectors.
/// If the extracted value is not distributed (the warp result type equals the
/// yielded type), the extract is simply moved outside of the region; otherwise
/// the source is yielded with a distributed type and the extract is re-created
/// outside. 1-D and 0-D sources are handled by WarpOpExtractScalar.
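///
/// For example (sketch), extracting row 1 of a distributed 2-D vector with a
/// warp size of 32:
/// ```
/// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = vector.extract %src[1] : vector<32xf32> from vector<4x32xf32>
///   gpu.yield %1 : vector<32xf32>
/// }
/// ```
/// becomes a warp op yielding %src as vector<4x1xf32>, followed by
/// `vector.extract %r[1] : vector<1xf32> from vector<4x1xf32>`.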
struct WarpOpExtract : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
    if (extractSrcType.getRank() <= 1) {
      return failure();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extract with a scalar result.
/// Only supports 1-D and 0-D sources for now.
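///
/// For the 1-D distributed case, one lane owns the requested element; it
/// extracts the element locally and the value is shuffled to all other lanes
/// via the user-provided `warpShuffleFromIdxFn`. E.g. (sketch) for
/// vector<64xf32> and a warp size of 32, each lane holds vector<2xf32>; the
/// owning lane extracts element `pos % 2` and shuffles it to the whole warp.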
struct WarpOpExtractScalar : public WarpDistributionPattern {
  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                      PatternBenefit b = 1)
      : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    // Only supports 1-D or 0-D sources for now.
    if (extractSrcType.getRank() > 1) {
      return rewriter.notifyMatchFailure(
          extractOp, "only 0-D or 1-D source supported for now");
    }
    // TODO: Supported shuffle types should be parameterizable, similar to
    // `WarpShuffleFromIdxFn`.
    if (!extractSrcType.getElementType().isF32() &&
        !extractSrcType.getElementType().isInteger(32))
      return rewriter.notifyMatchFailure(
          extractOp, "only f32/i32 element types are supported");
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extract src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield source vector and position (if present) from warp op.
    SmallVector<Value> additionalResults{extractOp.getVector()};
    SmallVector<Type> additionalResultTypes{distributedVecType};
    additionalResults.append(
        SmallVector<Value>(extractOp.getDynamicPosition()));
    additionalResultTypes.append(
        SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));

    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0d extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
      newExtract =
          rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    int64_t staticPos = extractOp.getStaticPosition()[0];
    OpFoldResult pos = ShapedType::isDynamic(staticPos)
                           ? (newWarpOp->getResult(newRetIndices[1]))
                           : OpFoldResult(rewriter.getIndexAttr(staticPos));
    // 1d extract: Distribute the source vector. One lane extracts and shuffles
    // the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = affine::makeComposedAffineApply(
        rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
    // Extract at position: pos % elementsPerLane
    Value newPos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : affine::makeComposedAffineApply(rewriter, loc,
                                              sym0 % elementsPerLane, pos);
    Value extracted =
        rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

/// Pattern to convert vector.extractelement to vector.extract.
struct WarpOpExtractElement : public WarpDistributionPattern {
  using Base::Base;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
    if (!operand)
      return failure();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    SmallVector<OpFoldResult> indices;
    if (auto pos = extractOp.getPosition()) {
      indices.push_back(pos);
    }
    rewriter.setInsertionPoint(extractOp);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
        extractOp, extractOp.getVector(), indices);
    return success();
  }
};

/// Pattern to move out vector.insert with a scalar input.
/// Only supports 1-D and 0-D destinations for now.
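///
/// When the destination is distributed, only the lane that owns the target
/// position performs the insert (guarded by an scf.if on the lane id); all
/// other lanes keep their distributed chunk unchanged. When the destination is
/// not distributed (broadcast), the insert is simply moved out of the region.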
1281 struct WarpOpInsertScalar : public WarpDistributionPattern {
1282  using Base::Base;
1283  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1284  PatternRewriter &rewriter) const override {
1285  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1286  if (!operand)
1287  return failure();
1288  unsigned int operandNumber = operand->getOperandNumber();
1289  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1290  VectorType vecType = insertOp.getDestVectorType();
1291  VectorType distrType =
1292  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1293 
1294  // Only supports 1-D or 0-D destinations for now.
1295  if (vecType.getRank() > 1) {
1296  return rewriter.notifyMatchFailure(
1297  insertOp, "only 0-D or 1-D source supported for now");
1298  }
1299 
1300  // Yield destination vector, source scalar and position from warp op.
1301  SmallVector<Value> additionalResults{insertOp.getDest(),
1302  insertOp.getValueToStore()};
1303  SmallVector<Type> additionalResultTypes{
1304  distrType, insertOp.getValueToStore().getType()};
1305  additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1306  additionalResultTypes.append(
1307  SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1308 
1309  Location loc = insertOp.getLoc();
1310  SmallVector<size_t> newRetIndices;
1311  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1312  rewriter, warpOp, additionalResults, additionalResultTypes,
1313  newRetIndices);
1314  rewriter.setInsertionPointAfter(newWarpOp);
1315  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1316  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1317  rewriter.setInsertionPointAfter(newWarpOp);
1318 
1319  OpFoldResult pos;
1320  if (vecType.getRank() != 0) {
1321  int64_t staticPos = insertOp.getStaticPosition()[0];
1322  pos = ShapedType::isDynamic(staticPos)
1323  ? (newWarpOp->getResult(newRetIndices[2]))
1324  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1325  }
1326 
1327  // This condition is always true for 0-d vectors.
1328  if (vecType == distrType) {
1329  Value newInsert;
1330  SmallVector<OpFoldResult> indices;
1331  if (pos) {
1332  indices.push_back(pos);
1333  }
1334  newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1335  distributedVec, indices);
1336  // Broadcast: Simply move the vector.insert op out.
1337  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1338  newInsert);
1339  return success();
1340  }
1341 
1342  // This is a distribution. Only one lane should insert.
1343  int64_t elementsPerLane = distrType.getShape()[0];
1344  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1345  // tid of the inserting lane: pos / elementsPerLane
1346  Value insertingLane = affine::makeComposedAffineApply(
1347  rewriter, loc, sym0.floorDiv(elementsPerLane), pos);
1348  // Insert position: pos % elementsPerLane
1349  OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1350  rewriter, loc, sym0 % elementsPerLane, pos);
1351  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1352  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1353  Value newResult =
1354  rewriter
1355  .create<scf::IfOp>(
1356  loc, isInsertingLane,
1357  /*thenBuilder=*/
1358  [&](OpBuilder &builder, Location loc) {
1359  Value newInsert = builder.create<vector::InsertOp>(
1360  loc, newSource, distributedVec, newPos);
1361  builder.create<scf::YieldOp>(loc, newInsert);
1362  },
1363  /*elseBuilder=*/
1364  [&](OpBuilder &builder, Location loc) {
1365  builder.create<scf::YieldOp>(loc, distributedVec);
1366  })
1367  .getResult(0);
1368  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1369  return success();
1370  }
1371 };
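A minimal sketch of the distributed case, assuming a warp of 32 lanes, a destination of type vector<64xf32> distributed as vector<2xf32> per lane, and hypothetical values %f (scalar source), %dist (distributed destination), %pos (dynamic position) and %laneid; the rewrite produces roughly:
```
%lane = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%pos]
%off  = affine.apply affine_map<()[s0] -> (s0 mod 2)>()[%pos]
%cond = arith.cmpi eq, %laneid, %lane : index
%r = scf.if %cond -> (vector<2xf32>) {
  %ins = vector.insert %f, %dist[%off] : f32 into vector<2xf32>
  scf.yield %ins : vector<2xf32>
} else {
  scf.yield %dist : vector<2xf32>
}
```
For example, %pos = 5 selects lane 5 floordiv 2 = 2 and offset 5 mod 2 = 1; when the position is static, the affine applies fold away to constants.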
1372 
1373 struct WarpOpInsert : public WarpDistributionPattern {
1374  using Base::Base;
1375  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1376  PatternRewriter &rewriter) const override {
1377  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1378  if (!operand)
1379  return failure();
1380  unsigned int operandNumber = operand->getOperandNumber();
1381  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1382  Location loc = insertOp.getLoc();
1383 
1384  // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1385  if (insertOp.getDestVectorType().getRank() <= 1) {
1386  return failure();
1387  }
1388 
1389  // All following cases are 2-D or higher dimensional destination vectors.
1390 
1391  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1392  // There is no distribution, this is a broadcast. Simply move the insert
1393  // out of the warp op.
1394  SmallVector<size_t> newRetIndices;
1395  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1396  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1397  {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1398  newRetIndices);
1399  rewriter.setInsertionPointAfter(newWarpOp);
1400  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1401  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1402  Value newResult = rewriter.create<vector::InsertOp>(
1403  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1404  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1405  newResult);
1406  return success();
1407  }
1408 
1409  // Find the distributed dimension. There should be exactly one.
1410  auto distrDestType =
1411  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1412  auto yieldedType = cast<VectorType>(operand->get().getType());
1413  int64_t distrDestDim = -1;
1414  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1415  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1416  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1417  // support distributing multiple dimensions in the future.
1418  assert(distrDestDim == -1 && "found multiple distributed dims");
1419  distrDestDim = i;
1420  }
1421  }
1422  assert(distrDestDim != -1 && "could not find distributed dimension");
1423 
1424  // Compute the distributed source vector type.
1425  VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1426  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1427  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1428  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1429  // insert a smaller vector<3xf32>.
1430  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1431  // case, one lane will insert the source vector<96xf32>. The other
1432  // lanes will not do anything.
1433  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1434  if (distrSrcDim >= 0)
1435  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1436  auto distrSrcType =
1437  VectorType::get(distrSrcShape, distrDestType.getElementType());
1438 
1439  // Yield source and dest vectors from warp op.
1440  SmallVector<size_t> newRetIndices;
1441  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1442  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1443  {distrSrcType, distrDestType}, newRetIndices);
1444  rewriter.setInsertionPointAfter(newWarpOp);
1445  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1446  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1447 
1448  // Insert into the distributed vector.
1449  Value newResult;
1450  if (distrSrcDim >= 0) {
1451  // Every lane inserts a small piece.
1452  newResult = rewriter.create<vector::InsertOp>(
1453  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1454  } else {
1455  // One lane inserts the entire source vector.
1456  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1457  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1458  SmallVector<int64_t> newPos = getAsIntegers(pos);
1459  // tid of inserting lane: pos / elementsPerLane
1460  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1461  loc, newPos[distrDestDim] / elementsPerLane);
1462  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1463  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1464  // Insert position: pos % elementsPerLane
1465  newPos[distrDestDim] %= elementsPerLane;
1466  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1467  Value newInsert = builder.create<vector::InsertOp>(
1468  loc, distributedSrc, distributedDest, newPos);
1469  builder.create<scf::YieldOp>(loc, newInsert);
1470  };
1471  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1472  builder.create<scf::YieldOp>(loc, distributedDest);
1473  };
1474  newResult = rewriter
1475  .create<scf::IfOp>(loc, isInsertingLane,
1476  /*thenBuilder=*/insertingBuilder,
1477  /*elseBuilder=*/nonInsertingBuilder)
1478  .getResult(0);
1479  }
1480 
1481  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1482  return success();
1483  }
1484 };
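Continuing the example from the comment above (vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>) with an assumed warp size of 32: if dim 1 is distributed (96 -> 3 per lane), every lane inserts its local vector<3xf32> at the same position [2]; if dim 0 is distributed (128 -> 4 rows per lane), elementsPerLane is 4, so only lane 2 / 4 = 0 inserts the full vector<96xf32> at row 2 % 4 = 2 of its vector<4x96xf32>, while all other lanes yield their distributed destination unchanged.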
1485 
1486 struct WarpOpInsertElement : public WarpDistributionPattern {
1487  using Base::Base;
1488  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1489  PatternRewriter &rewriter) const override {
1490  OpOperand *operand =
1491  getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1492  if (!operand)
1493  return failure();
1494  auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1495  SmallVector<OpFoldResult> indices;
1496  if (auto pos = insertOp.getPosition()) {
1497  indices.push_back(pos);
1498  }
1499  rewriter.setInsertionPoint(insertOp);
1500  rewriter.replaceOpWithNewOp<vector::InsertOp>(
1501  insertOp, insertOp.getSource(), insertOp.getDest(), indices);
1502  return success();
1503  }
1504 };
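Analogous to WarpOpExtractElement above, this is a pure canonicalization to the position-based op; a sketch with hypothetical values %f, %v and index %i:
```
%r = vector.insertelement %f, %v[%i : index] : vector<4xf32>
// becomes
%r = vector.insert %f, %v[%i] : f32 into vector<4xf32>
```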
1505 
1506 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1507 /// the scf.ForOp is the last operation in the region so that it doesn't
1508 /// change the order of execution. This creates a new scf.for region after the
1509 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1510 /// WarpExecuteOnLane0Op region. Example:
1511 /// ```
1512 /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1513 /// ...
1514 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1515 /// -> (vector<128xf32>) {
1516 /// ...
1517 /// scf.yield %r : vector<128xf32>
1518 /// }
1519 /// gpu.yield %v1 : vector<128xf32>
1520 /// }
1521 /// ```
1522 /// To:
1523 /// %w0 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1524 /// ...
1525 /// gpu.yield %v : vector<128xf32>
1526 /// }
1527 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1528 /// -> (vector<4xf32>) {
1529 /// %iw = gpu.warp_execute_on_lane_0(%laneid)
1530 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1531 /// ^bb0(%arg: vector<128xf32>):
1532 /// ...
1533 /// gpu.yield %ir : vector<128xf32>
1534 /// }
1535 /// scf.yield %iw : vector<4xf32>
1536 /// }
1537 /// ```
1538 struct WarpOpScfForOp : public WarpDistributionPattern {
1539 
1540  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1541  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1542  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1543  PatternRewriter &rewriter) const override {
1544  auto yield = cast<gpu::YieldOp>(
1545  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1546  // Only pick up forOp if it is the last op in the region.
1547  Operation *lastNode = yield->getPrevNode();
1548  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1549  if (!forOp)
1550  return failure();
1551  // Collect Values that come from the warp op but are outside the forOp.
1552  // Those Values need to be returned by the original warpOp and passed to
1553  // the new op.
1554  llvm::SmallSetVector<Value, 32> escapingValues;
1555  SmallVector<Type> inputTypes;
1556  SmallVector<Type> distTypes;
1557  auto collectEscapingValues = [&](Value value) {
1558  if (!escapingValues.insert(value))
1559  return;
1560  Type distType = value.getType();
1561  if (auto vecType = dyn_cast<VectorType>(distType)) {
1562  AffineMap map = distributionMapFn(value);
1563  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1564  }
1565  inputTypes.push_back(value.getType());
1566  distTypes.push_back(distType);
1567  };
1568 
1569  mlir::visitUsedValuesDefinedAbove(
1570  forOp.getBodyRegion(), [&](OpOperand *operand) {
1571  Operation *parent = operand->get().getParentRegion()->getParentOp();
1572  if (warpOp->isAncestor(parent)) {
1573  collectEscapingValues(operand->get());
1574  }
1575  });
1576 
1577  // Any forOp result that is not already yielded by the warpOp
1578  // region is also considered escaping and must be returned by the
1579  // original warpOp.
1580  for (OpResult forResult : forOp.getResults()) {
1581  // Check if this forResult is already yielded by the yield op.
1582  if (llvm::is_contained(yield->getOperands(), forResult))
1583  continue;
1584  collectEscapingValues(forResult);
1585  }
1586 
1587  if (llvm::is_contained(distTypes, Type{}))
1588  return failure();
1589 
1590  SmallVector<size_t> newRetIndices;
1591  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1592  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1593  newRetIndices);
1594  yield = cast<gpu::YieldOp>(
1595  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1596 
1597  SmallVector<Value> newOperands;
1598  SmallVector<unsigned> resultIdx;
1599  // Collect all the outputs coming from the forOp.
1600  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1601  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1602  continue;
1603  auto forResult = cast<OpResult>(yieldOperand.get());
1604  newOperands.push_back(
1605  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1606  yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1607  resultIdx.push_back(yieldOperand.getOperandNumber());
1608  }
1609 
1610  OpBuilder::InsertionGuard g(rewriter);
1611  rewriter.setInsertionPointAfter(newWarpOp);
1612 
1613  // Create a new for op outside the region with a WarpExecuteOnLane0Op
1614  // region inside.
1615  auto newForOp = rewriter.create<scf::ForOp>(
1616  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1617  forOp.getStep(), newOperands);
1618  rewriter.setInsertionPointToStart(newForOp.getBody());
1619 
1620  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1621  newForOp.getRegionIterArgs().end());
1622  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1623  forOp.getResultTypes().end());
1624  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1625  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1626  auto newWarpResult = newWarpOp.getResult(retIdx);
1627  // Unused forOp results yielded by the warpOp region are already included
1628  // in the new ForOp.
1629  if (llvm::is_contained(newOperands, newWarpResult))
1630  continue;
1631  warpInput.push_back(newWarpResult);
1632  argIndexMapping[escapingValues[i]] = warpInputType.size();
1633  warpInputType.push_back(inputTypes[i]);
1634  }
1635  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1636  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1637  newWarpOp.getWarpSize(), warpInput, warpInputType);
1638 
1639  SmallVector<Value> argMapping;
1640  argMapping.push_back(newForOp.getInductionVar());
1641  for (Value args : innerWarp.getBody()->getArguments()) {
1642  argMapping.push_back(args);
1643  }
1644  argMapping.resize(forOp.getBody()->getNumArguments());
1645  SmallVector<Value> yieldOperands;
1646  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1647  yieldOperands.push_back(operand);
1648  rewriter.eraseOp(forOp.getBody()->getTerminator());
1649  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1650  rewriter.setInsertionPointToEnd(innerWarp.getBody());
1651  rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
1652  rewriter.setInsertionPointAfter(innerWarp);
1653  if (!innerWarp.getResults().empty())
1654  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1655  rewriter.eraseOp(forOp);
1656  // Replace the warpOp result coming from the original ForOp.
1657  for (const auto &res : llvm::enumerate(resultIdx)) {
1658  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1659  newForOp.getResult(res.index()));
1660  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1661  }
1662  newForOp.walk([&](Operation *op) {
1663  for (OpOperand &operand : op->getOpOperands()) {
1664  auto it = argIndexMapping.find(operand.get());
1665  if (it == argIndexMapping.end())
1666  continue;
1667  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1668  }
1669  });
1670 
1671  // Finally, hoist out any now uniform code from the inner warp op.
1672  mlir::vector::moveScalarUniformCode(innerWarp);
1673  return success();
1674  }
1675 
1676 private:
1677  DistributionMapFn distributionMapFn;
1678 };
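In addition to the loop-carried values shown in the example above, a value defined in the warp region but used inside the loop (an escaping value) is appended to the results of the first warp op in its distributed type and then passed back into the inner warp op as an extra argument with its original sequential type. A hedged sketch, assuming a hypothetical %esc : vector<128xf32> distributed to vector<4xf32> with warp size 32:
```
%w:2 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<4xf32>, vector<4xf32>) {
  %esc = "some_def"() : () -> (vector<128xf32>)
  ...
  gpu.yield %v, %esc : vector<128xf32>, vector<128xf32>
}
%f = scf.for %i = %c0 to %c128 step %c1 iter_args(%varg = %w#0) -> (vector<4xf32>) {
  %iw = gpu.warp_execute_on_lane_0(%laneid)[32]
      args(%varg, %w#1 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>) {
  ^bb0(%arg : vector<128xf32>, %escarg : vector<128xf32>):
    // the loop body uses %escarg in place of %esc
    gpu.yield %ir : vector<128xf32>
  }
  scf.yield %iw : vector<4xf32>
}
```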
1679 
1680 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1681 /// The vector is reduced in parallel. Currently limited to vector sizes
1682 /// that are a multiple of the warp size. E.g.:
1683 /// ```
1684 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1685 /// %0 = "some_def"() : () -> (vector<32xf32>)
1686 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1687 /// gpu.yield %1 : f32
1688 /// }
1689 /// ```
1690 /// is lowered to:
1691 /// ```
1692 /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1693 /// %1 = "some_def"() : () -> (vector<32xf32>)
1694 /// gpu.yield %1 : vector<32xf32>
1695 /// }
1696 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1697 /// %r = ("warp.reduction %a")
1698 /// ```
1699 struct WarpOpReduction : public WarpDistributionPattern {
1700  WarpOpReduction(MLIRContext *context,
1701  DistributedReductionFn distributedReductionFn,
1702  PatternBenefit benefit = 1)
1703  : WarpDistributionPattern(context, benefit),
1704  distributedReductionFn(std::move(distributedReductionFn)) {}
1705 
1706  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1707  PatternRewriter &rewriter) const override {
1708  OpOperand *yieldOperand =
1709  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1710  if (!yieldOperand)
1711  return failure();
1712 
1713  auto reductionOp =
1714  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1715  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1716  // Only rank 1 vectors supported.
1717  if (vectorType.getRank() != 1)
1718  return rewriter.notifyMatchFailure(
1719  warpOp, "Only rank 1 reductions can be distributed.");
1720  // Only warp_size-sized vectors supported.
1721  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1722  return rewriter.notifyMatchFailure(
1723  warpOp, "Reduction vector dimension must be a multiple of the warp size.");
1724  if (!reductionOp.getType().isIntOrFloat())
1725  return rewriter.notifyMatchFailure(
1726  warpOp, "Reduction distribution currently only supports floats and "
1727  "integer types.");
1728 
1729  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1730  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1731  unsigned operandIndex = yieldOperand->getOperandNumber();
1732  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1733  SmallVector<Type> retTypes = {
1734  VectorType::get({numElements}, reductionOp.getType())};
1735  if (reductionOp.getAcc()) {
1736  yieldValues.push_back(reductionOp.getAcc());
1737  retTypes.push_back(reductionOp.getAcc().getType());
1738  }
1739  SmallVector<size_t> newRetIndices;
1740  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1741  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1742  rewriter.setInsertionPointAfter(newWarpOp);
1743 
1744  // Obtain data to reduce for a single lane.
1745  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1746  // Distribute and reduce across threads.
1747  Value fullReduce =
1748  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1749  reductionOp.getKind(), newWarpOp.getWarpSize());
1750  if (reductionOp.getAcc()) {
1751  fullReduce = vector::makeArithReduction(
1752  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1753  newWarpOp.getResult(newRetIndices[1]));
1754  }
1755  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1756  return success();
1757  }
1758 
1759 private:
1760  DistributedReductionFn distributedReductionFn;
1761 };
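As a worked example of the multiple-of-warp-size case: with warp size 32 and a vector<64xf32> input, numElements = 64 / 32 = 2, so the rewritten warp op yields a vector<2xf32> per lane; the supplied distributedReductionFn is then responsible for combining the two per-lane elements and the 32 lanes (typically via warp shuffles) into the final scalar, and any accumulator is folded in afterwards with makeArithReduction.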
1762 
1763 } // namespace
1764 
1765 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1766  RewritePatternSet &patterns,
1767  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
1768  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1769 }
1770 
1771 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1772  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1773  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
1774  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1775  maxNumElementsToExtract, benefit);
1776 }
1777 
1778 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1779  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1780  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1781  PatternBenefit readBenefit) {
1782  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1783  patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1784  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1785  WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1786  WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1787  patterns.getContext(), benefit);
1788  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
1789  benefit);
1790  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1791  benefit);
1792 }
1793 
1794 void mlir::vector::populateDistributeReduction(
1795  RewritePatternSet &patterns,
1796  const DistributedReductionFn &distributedReductionFn,
1797  PatternBenefit benefit) {
1798  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1799  benefit);
1800 }
1801 
1802 /// Helper to know if an op can be hoisted out of the region.
1803 static bool canBeHoisted(Operation *op,
1804  function_ref<bool(Value)> definedOutside) {
1805  return llvm::all_of(op->getOperands(), definedOutside) &&
1806  isMemoryEffectFree(op) && op->getNumRegions() == 0;
1807 }
1808 
1809 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1810  Block *body = warpOp.getBody();
1811 
1812  // Keep track of the ops we want to hoist.
1813  llvm::SmallSetVector<Operation *, 8> opsToMove;
1814 
1815  // Helper to check if a value is or will be defined outside of the region.
1816  auto isDefinedOutsideOfBody = [&](Value value) {
1817  auto *definingOp = value.getDefiningOp();
1818  return (definingOp && opsToMove.count(definingOp)) ||
1819  warpOp.isDefinedOutsideOfRegion(value);
1820  };
1821 
1822  // Do not use walk here, as we do not want to go into nested regions and hoist
1823  // operations from there.
1824  for (auto &op : body->without_terminator()) {
1825  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1826  return isa<VectorType>(result.getType());
1827  });
1828  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1829  opsToMove.insert(&op);
1830  }
1831 
1832  // Move all the ops marked as uniform outside of the region.
1833  for (Operation *op : opsToMove)
1834  op->moveBefore(warpOp);
1835 }
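A minimal sketch of the hoisting, assuming hypothetical index values %a and %b defined above the warp op: the scalar arith.addi produces no vector result and depends only on values defined outside the region, so it is moved before the op, while vector-producing operations stay inside.
```
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %i = arith.addi %a, %b : index
  %v = "some_def"(%i) : (index) -> (vector<1xf32>)
  gpu.yield %v : vector<1xf32>
}
// after moveScalarUniformCode:
%i = arith.addi %a, %b : index
%w = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  %v = "some_def"(%i) : (index) -> (vector<1xf32>)
  gpu.yield %v : vector<1xf32>
}
```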