1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <functional>
#include <numeric>
#include <utility>
23 using namespace mlir;
24 using namespace mlir::vector;
26 /// Currently the distribution map is implicit based on the vector shape. In the
27 /// future it will be part of the op.
28 /// Example:
29 /// ```
30 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
31 /// ...
32 /// vector.yield %3 : vector<32x16x64xf32>
33 /// }
34 /// ```
35 /// Would have an implicit map of:
36 /// `(d0, d1, d2) -> (d0, d2)`
37 static AffineMap calculateImplicitMap(VectorType sequentialType,
38  VectorType distributedType) {
40  perm.reserve(1);
41  // Check which dimensions of the sequential type are different than the
42  // dimensions of the distributed type to know the distributed dimensions. Then
43  // associate each distributed dimension to an ID in order.
44  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
45  if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
46  perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
47  }
48  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
49  distributedType.getContext());
50  return map;
51 }
53 namespace {
55 /// Helper struct to create the load / store operations that permit transit
56 /// through the parallel / sequential and the sequential / parallel boundaries
57 /// when performing `rewriteWarpOpToScfFor`.
58 ///
59 /// The vector distribution dimension is inferred from the vector types.
60 struct DistributedLoadStoreHelper {
61  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
62  Value laneId, Value zero)
63  : sequentialVal(sequentialVal), distributedVal(distributedVal),
64  laneId(laneId), zero(zero) {
65  sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
66  distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
67  if (sequentialVectorType && distributedVectorType)
68  distributionMap =
69  calculateImplicitMap(sequentialVectorType, distributedVectorType);
70  }
72  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
73  int64_t distributedSize = distributedVectorType.getDimSize(index);
75  return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
76  ArrayRef<Value>{laneId});
77  }
79  /// Create a store during the process of distributing the
80  /// `vector.warp_execute_on_thread_0` op.
81  /// Vector distribution assumes the following convention regarding the
82  /// temporary buffers that are created to transition values. This **must**
83  /// be properly specified in the `options.warpAllocationFn`:
84  /// 1. scalars of type T transit through a memref<1xT>.
85  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
86  Operation *buildStore(RewriterBase &b, Location loc, Value val,
87  Value buffer) {
88  assert((val == distributedVal || val == sequentialVal) &&
89  "Must store either the preregistered distributed or the "
90  "preregistered sequential value.");
91  // Scalar case can directly use
92  if (!isa<VectorType>(val.getType()))
93  return b.create<memref::StoreOp>(loc, val, buffer, zero);
95  // Vector case must use vector::TransferWriteOp which will later lower to
96  // of depending on further lowerings.
97  int64_t rank = sequentialVectorType.getRank();
98  SmallVector<Value> indices(rank, zero);
99  if (val == distributedVal) {
100  for (auto dimExpr : distributionMap.getResults()) {
101  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
102  indices[index] = buildDistributedOffset(b, loc, index);
103  }
104  }
105  SmallVector<bool> inBounds(indices.size(), true);
106  return b.create<vector::TransferWriteOp>(
107  loc, val, buffer, indices,
108  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
109  }
111  /// Create a load during the process of distributing the
112  /// `vector.warp_execute_on_thread_0` op.
113  /// Vector distribution assumes the following convention regarding the
114  /// temporary buffers that are created to transition values. This **must**
115  /// be properly specified in the `options.warpAllocationFn`:
116  /// 1. scalars of type T transit through a memref<1xT>.
117  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
118  ///
119  /// When broadcastMode is true, the load is not distributed to account for
120  /// the broadcast semantics of the `vector.warp_execute_on_lane_0` op.
121  ///
122  /// Example:
123  ///
124  /// ```
125  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
126  /// vector.yield %cst : f32
127  /// }
128  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
129  /// ```
130  /// This behavior described in more detail in the documentation of the op.
131  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
133  // Scalar case can directly use
134  if (!isa<VectorType>(type))
135  return b.create<memref::LoadOp>(loc, buffer, zero);
137  // Other cases must be vector atm.
138  // Vector case must use vector::TransferReadOp which will later lower to
139  // of depending on further lowerings.
140  assert((type == distributedVectorType || type == sequentialVectorType) &&
141  "Must store either the preregistered distributed or the "
142  "preregistered sequential type.");
143  SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
144  if (type == distributedVectorType) {
145  for (auto dimExpr : distributionMap.getResults()) {
146  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
147  indices[index] = buildDistributedOffset(b, loc, index);
148  }
149  }
150  SmallVector<bool> inBounds(indices.size(), true);
151  return b.create<vector::TransferReadOp>(
152  loc, cast<VectorType>(type), buffer, indices,
153  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
154  }
156  Value sequentialVal, distributedVal, laneId, zero;
157  VectorType sequentialVectorType, distributedVectorType;
158  AffineMap distributionMap;
159 };
161 } // namespace
163 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
164 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
165  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
166  ValueRange newYieldedValues, TypeRange newReturnTypes) {
167  // Create a new op before the existing one, with the extra operands.
168  OpBuilder::InsertionGuard g(rewriter);
169  rewriter.setInsertionPoint(warpOp);
170  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
171  warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
172  warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
174  Region &opBody = warpOp.getBodyRegion();
175  Region &newOpBody = newWarpOp.getBodyRegion();
176  Block &newOpFirstBlock = newOpBody.front();
177  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
178  rewriter.eraseBlock(&newOpFirstBlock);
179  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
180  "expected WarpOp with single block");
182  auto yield =
183  cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
185  rewriter.modifyOpInPlace(
186  yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
187  return newWarpOp;
188 }
190 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
191 /// `indices` return the index of each new output.
192 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
193  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
194  ValueRange newYieldedValues, TypeRange newReturnTypes,
195  llvm::SmallVector<size_t> &indices) {
196  SmallVector<Type> types(warpOp.getResultTypes().begin(),
197  warpOp.getResultTypes().end());
198  auto yield = cast<vector::YieldOp>(
199  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
200  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
201  yield.getOperands().end());
202  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
203  if (yieldValues.insert(std::get<0>(newRet))) {
204  types.push_back(std::get<1>(newRet));
205  indices.push_back(yieldValues.size() - 1);
206  } else {
207  // If the value already exit the region don't create a new output.
208  for (auto [idx, yieldOperand] :
209  llvm::enumerate(yieldValues.getArrayRef())) {
210  if (yieldOperand == std::get<0>(newRet)) {
211  indices.push_back(idx);
212  break;
213  }
214  }
215  }
216  }
217  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
218  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
219  rewriter, warpOp, yieldValues.getArrayRef(), types);
220  rewriter.replaceOp(warpOp,
221  newWarpOp.getResults().take_front(warpOp.getNumResults()));
222  return newWarpOp;
223 }
225 /// Helper to know if an op can be hoisted out of the region.
226 static bool canBeHoisted(Operation *op,
227  function_ref<bool(Value)> definedOutside) {
228  return llvm::all_of(op->getOperands(), definedOutside) &&
229  isMemoryEffectFree(op) && op->getNumRegions() == 0;
230 }
232 /// Return a value yielded by `warpOp` which statifies the filter lamdba
233 /// condition and is not dead.
234 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
235  const std::function<bool(Operation *)> &fn) {
236  auto yield = cast<vector::YieldOp>(
237  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
238  for (OpOperand &yieldOperand : yield->getOpOperands()) {
239  Value yieldValues = yieldOperand.get();
240  Operation *definedOp = yieldValues.getDefiningOp();
241  if (definedOp && fn(definedOp)) {
242  if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
243  return &yieldOperand;
244  }
245  }
246  return {};
247 }
249 // Clones `op` into a new operation that takes `operands` and returns
250 // `resultTypes`.
252  Location loc, Operation *op,
253  ArrayRef<Value> operands,
254  ArrayRef<Type> resultTypes) {
255  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
256  op->getAttrs());
257  return rewriter.create(res);
258 }
260 namespace {
262 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
263 /// thread `laneId` executes the entirety of the computation.
264 ///
265 /// After the transformation:
266 /// - the IR within the scf.if op can be thought of as executing sequentially
267 /// (from the point of view of threads along `laneId`).
268 /// - the IR outside of the scf.if op can be thought of as executing in
269 /// parallel (from the point of view of threads along `laneId`).
270 ///
271 /// Values that need to transit through the parallel / sequential and the
272 /// sequential / parallel boundaries do so via reads and writes to a temporary
273 /// memory location.
274 ///
275 /// The transformation proceeds in multiple steps:
276 /// 1. Create the scf.if op.
277 /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
278 /// within the scf.if to transit the values captured from above.
279 /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
280 /// consistent within the scf.if.
281 /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
282 /// 5. Insert appropriate writes within scf.if and reads after the scf.if to
283 /// transit the values returned by the op.
284 /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
285 /// consistent after the scf.if.
286 /// 7. Perform late cleanups.
287 ///
288 /// All this assumes the vector distribution occurs along the most minor
289 /// distributed vector dimension.
290 struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
291  WarpOpToScfIfPattern(MLIRContext *context,
293  PatternBenefit benefit = 1)
294  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
295  options(options) {}
297  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
298  PatternRewriter &rewriter) const override {
299  assert(warpOp.getBodyRegion().hasOneBlock() &&
300  "expected WarpOp with single block");
301  Block *warpOpBody = &warpOp.getBodyRegion().front();
302  Location loc = warpOp.getLoc();
304  // Passed all checks. Start rewriting.
305  OpBuilder::InsertionGuard g(rewriter);
306  rewriter.setInsertionPoint(warpOp);
308  // Step 1: Create scf.if op.
309  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
310  Value isLane0 = rewriter.create<arith::CmpIOp>(
311  loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
312  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
313  /*withElseRegion=*/false);
314  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
316  // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
317  // reads within the scf.if to transit the values captured from above.
318  SmallVector<Value> bbArgReplacements;
319  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
320  Value sequentialVal = warpOpBody->getArgument(it.index());
321  Value distributedVal = it.value();
322  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
323  warpOp.getLaneid(), c0);
325  // Create buffer before the ifOp.
326  rewriter.setInsertionPoint(ifOp);
327  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
328  sequentialVal.getType());
329  // Store distributed vector into buffer, before the ifOp.
330  helper.buildStore(rewriter, loc, distributedVal, buffer);
331  // Load sequential vector from buffer, inside the ifOp.
332  rewriter.setInsertionPointToStart(ifOp.thenBlock());
333  bbArgReplacements.push_back(
334  helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
335  }
337  // Step 3. Insert sync after all the stores and before all the loads.
338  if (!warpOp.getArgs().empty()) {
339  rewriter.setInsertionPoint(ifOp);
340  options.warpSyncronizationFn(loc, rewriter, warpOp);
341  }
343  // Step 4. Move body of warpOp to ifOp.
344  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
346  // Step 5. Insert appropriate writes within scf.if and reads after the
347  // scf.if to transit the values returned by the op.
348  // TODO: at this point, we can reuse the shared memory from previous
349  // buffers.
350  SmallVector<Value> replacements;
351  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
352  Location yieldLoc = yieldOp.getLoc();
353  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
354  Value sequentialVal = it.value();
355  Value distributedVal = warpOp->getResult(it.index());
356  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
357  warpOp.getLaneid(), c0);
359  // Create buffer before the ifOp.
360  rewriter.setInsertionPoint(ifOp);
361  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
362  sequentialVal.getType());
364  // Store yielded value into buffer, inside the ifOp, before the
365  // terminator.
366  rewriter.setInsertionPoint(yieldOp);
367  helper.buildStore(rewriter, loc, sequentialVal, buffer);
369  // Load distributed value from buffer, after the warpOp.
370  rewriter.setInsertionPointAfter(ifOp);
371  // Result type and yielded value type are the same. This is a broadcast.
372  // E.g.:
373  // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
374  // vector.yield %cst : f32
375  // }
376  // Both types are f32. The constant %cst is broadcasted to all lanes.
377  // This is described in more detail in the documentation of the op.
378  replacements.push_back(
379  helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
380  }
382  // Step 6. Insert sync after all the stores and before all the loads.
383  if (!yieldOp.getOperands().empty()) {
384  rewriter.setInsertionPointAfter(ifOp);
385  options.warpSyncronizationFn(loc, rewriter, warpOp);
386  }
388  // Step 7. Delete terminator and add empty scf.yield.
389  rewriter.eraseOp(yieldOp);
390  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
391  rewriter.create<scf::YieldOp>(yieldLoc);
393  // Compute replacements for WarpOp results.
394  rewriter.replaceOp(warpOp, replacements);
396  return success();
397  }
399 private:
401 };
403 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
404 /// op with the proper return type.
405 /// The new write op is updated to write the result of the new warp execute op.
406 /// The old `writeOp` is deleted.
407 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
408  WarpExecuteOnLane0Op warpOp,
409  vector::TransferWriteOp writeOp,
410  VectorType targetType,
411  VectorType maybeMaskType) {
412  assert(writeOp->getParentOp() == warpOp &&
413  "write must be nested immediately under warp");
414  OpBuilder::InsertionGuard g(rewriter);
415  SmallVector<size_t> newRetIndices;
416  WarpExecuteOnLane0Op newWarpOp;
417  if (maybeMaskType) {
419  rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
420  TypeRange{targetType, maybeMaskType}, newRetIndices);
421  } else {
423  rewriter, warpOp, ValueRange{{writeOp.getVector()}},
424  TypeRange{targetType}, newRetIndices);
425  }
426  rewriter.setInsertionPointAfter(newWarpOp);
427  auto newWriteOp =
428  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
429  rewriter.eraseOp(writeOp);
430  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
431  if (maybeMaskType)
432  newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
433  return newWriteOp;
434 }
436 /// Return the distributed vector type based on the original type and the
437 /// distribution map. The map is expected to have a dimension equal to the
438 /// original type rank and should be a projection where the results are the
439 /// distributed dimensions. The number of results should be equal to the number
440 /// of warp sizes which is currently limited to 1.
441 /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
442 /// and a warp size of 16 would distribute the second dimension (associated to
443 /// d1) and return vector<16x2x64>
444 static VectorType getDistributedType(VectorType originalType, AffineMap map,
445  int64_t warpSize) {
446  SmallVector<int64_t> targetShape(originalType.getShape());
447  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
448  unsigned position = map.getDimPosition(i);
449  if (targetShape[position] % warpSize != 0) {
450  if (warpSize % targetShape[position] != 0) {
451  return VectorType();
452  }
453  warpSize /= targetShape[position];
454  targetShape[position] = 1;
455  continue;
456  }
457  targetShape[position] = targetShape[position] / warpSize;
458  warpSize = 1;
459  break;
460  }
461  if (warpSize != 1) {
462  return VectorType();
463  }
464  VectorType targetType =
465  VectorType::get(targetShape, originalType.getElementType());
466  return targetType;
467 }
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes that cannot be distributed are instead
/// extracted into their own warp op, unless they write more than
/// `maxNumElementsToExtract` elements (it should be less than the warp size).
///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
  /// are multiples of the distribution ratio are supported at the moment.
  /// Returns failure, without touching the IR, when the write is 0-D, when no
  /// distributed type can be computed, or when the write is masked and its
  /// permutation map is not a minor identity.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. 0-D writes are not distributed here; bail out so the caller can fall
    // back to extraction.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5 Compute the distributed type for the new mask;
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element is captured strictly
      // by shape information on the warp op, and thus requires materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    // For each distributed vector dimension, shift the matching write index by
    // `laneId * distributedDimSize`.
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      // Results that are not plain dimensions have no index to adjust.
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  /// The written vector becomes an extra result of the first warp op and the
  /// write is re-created inside a second warp op that follows it. Fails when
  /// the write has more than `maxNumElementsToExtract` elements or when the
  /// warp op already contains nothing but transfer writes and its terminator.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  /// Matches a warp op whose last operation before the terminator is a
  /// transfer_write. Distribution is tried first; unmasked writes fall back to
  /// extraction.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    // Apart from the written vector and the mask, every operand of the write
    // must be defined outside of the warp region.
    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  /// Maps the written vector value to the affine map naming its distributed
  /// dimensions.
  DistributionMapFn distributionMapFn;
  /// Upper bound on the number of elements a write may have to be extracted.
  unsigned maxNumElementsToExtract = 1;
};
659 /// Sink out elementwise op feeding into a warp op yield.
660 /// ```
661 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
662 /// ...
663 /// %3 = arith.addf %1, %2 : vector<32xf32>
664 /// vector.yield %3 : vector<32xf32>
665 /// }
666 /// ```
667 /// To
668 /// ```
669 /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
670 /// vector<1xf32>, vector<1xf32>) {
671 /// ...
672 /// %4 = arith.addf %2, %3 : vector<32xf32>
673 /// vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
674 /// vector<32xf32>
675 /// }
676 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
677 struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
679  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
680  PatternRewriter &rewriter) const override {
681  OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
683  });
684  if (!yieldOperand)
685  return failure();
687  Operation *elementWise = yieldOperand->get().getDefiningOp();
688  unsigned operandIndex = yieldOperand->getOperandNumber();
689  Value distributedVal = warpOp.getResult(operandIndex);
690  SmallVector<Value> yieldValues;
691  SmallVector<Type> retTypes;
692  Location loc = warpOp.getLoc();
693  for (OpOperand &operand : elementWise->getOpOperands()) {
694  Type targetType;
695  if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
696  // If the result type is a vector, the operands must also be vectors.
697  auto operandType = cast<VectorType>(operand.get().getType());
698  targetType =
699  VectorType::get(vecType.getShape(), operandType.getElementType());
700  } else {
701  auto operandType = operand.get().getType();
702  assert(!isa<VectorType>(operandType) &&
703  "unexpected yield of vector from op with scalar result type");
704  targetType = operandType;
705  }
706  retTypes.push_back(targetType);
707  yieldValues.push_back(operand.get());
708  }
709  SmallVector<size_t> newRetIndices;
710  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
711  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
712  rewriter.setInsertionPointAfter(newWarpOp);
713  SmallVector<Value> newOperands(elementWise->getOperands().begin(),
714  elementWise->getOperands().end());
715  for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
716  newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
717  }
718  OpBuilder::InsertionGuard g(rewriter);
719  rewriter.setInsertionPointAfter(newWarpOp);
721  rewriter, loc, elementWise, newOperands,
722  {newWarpOp.getResult(operandIndex).getType()});
723  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
724  newOp->getResult(0));
725  return success();
726  }
727 };
729 /// Sink out splat constant op feeding into a warp op yield.
730 /// ```
731 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
732 /// ...
733 /// %cst = arith.constant dense<2.0> : vector<32xf32>
734 /// vector.yield %cst : vector<32xf32>
735 /// }
736 /// ```
737 /// To
738 /// ```
739 /// vector.warp_execute_on_lane_0(%arg0 {
740 /// ...
741 /// }
742 /// %0 = arith.constant dense<2.0> : vector<1xf32>
743 struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
745  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
746  PatternRewriter &rewriter) const override {
747  OpOperand *yieldOperand =
748  getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
749  if (!yieldOperand)
750  return failure();
751  auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
752  auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
753  if (!dense)
754  return failure();
755  // Notify the rewriter that the warp op is changing (see the comment on
756  // the WarpOpTransferRead pattern).
757  rewriter.startOpModification(warpOp);
758  unsigned operandIndex = yieldOperand->getOperandNumber();
759  Attribute scalarAttr = dense.getSplatValue<Attribute>();
760  auto newAttr = DenseElementsAttr::get(
761  cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
762  Location loc = warpOp.getLoc();
763  rewriter.setInsertionPointAfter(warpOp);
764  Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
765  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
766  rewriter.finalizeOpModification(warpOp);
767  return success();
768  }
769 };
771 /// Delinearize the given `laneId` into multiple dimensions, where each
772 /// dimension's size is determined by `originalShape` and `distributedShape`
773 /// together. This function expects the total numbers of threads needed for
774 /// distribution is equal to `warpSize`. Returns true and updates
775 /// `delinearizedIds` if so.
776 bool delinearizeLaneId(OpBuilder &builder, Location loc,
777  ArrayRef<int64_t> originalShape,
778  ArrayRef<int64_t> distributedShape, int64_t warpSize,
779  Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
780  // If the original shape and the distributed shape is the same, we don't
781  // distribute at all--every thread is handling the whole. For such case, we
782  // should not rely on lane IDs later. So just return an empty lane ID vector.
783  if (originalShape == distributedShape) {
784  delinearizedIds.clear();
785  return true;
786  }
788  SmallVector<int64_t> sizes;
789  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
790  if (large % small != 0)
791  return false;
792  sizes.push_back(large / small);
793  }
794  if (std::accumulate(sizes.begin(), sizes.end(), 1,
795  std::multiplies<int64_t>()) != warpSize)
796  return false;
798  AffineExpr s0, s1;
799  bindSymbols(builder.getContext(), s0, s1);
801  int64_t usedThreads = 1;
803  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
804  delinearizedIds.assign(sizes.size(), zero);
806  for (int i = sizes.size() - 1; i >= 0; --i) {
807  usedThreads *= sizes[i];
808  if (usedThreads == warpSize) {
809  // We've used up all available threads. Don't need to perform modulo
810  // anymore. And we can stop the calculation for further dimensions.
811  delinearizedIds[i] = laneId;
812  break;
813  }
814  delinearizedIds[i] =
815  affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
817  builder, loc, s0.floorDiv(usedThreads), {laneId});
818  }
819  return true;
820 }
822 /// Sink out transfer_read op feeding into a warp op yield.
823 /// ```
824 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
825 /// ...
826 // %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
827 // vector<32xf32>
828 /// vector.yield %2 : vector<32xf32>
829 /// }
830 /// ```
831 /// To
832 /// ```
833 /// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
834 /// vector<1xf32>, vector<1xf32>) {
835 /// ...
836 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
837 /// vector<32xf32> vector.yield %2 : vector<32xf32>
838 /// }
839 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
840 struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
842  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
843  PatternRewriter &rewriter) const override {
844  // Try to find a distributable yielded read. Note that this pattern can
845  // still fail at the end after distribution, in which case this might have
846  // missed another distributable read.
847  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
848  // Don't duplicate transfer_read ops when distributing.
849  return isa<vector::TransferReadOp>(op) && op->hasOneUse();
850  });
851  if (!operand)
852  return rewriter.notifyMatchFailure(
853  warpOp, "warp result is not a vector.transfer_read op");
854  auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
856  // Source must be defined outside of the region.
857  if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
858  return rewriter.notifyMatchFailure(
859  read, "source must be defined outside of the region");
861  unsigned operandIndex = operand->getOperandNumber();
862  Value distributedVal = warpOp.getResult(operandIndex);
864  SmallVector<Value, 4> indices(read.getIndices().begin(),
865  read.getIndices().end());
866  auto sequentialType = cast<VectorType>(read.getResult().getType());
867  auto distributedType = cast<VectorType>(distributedVal.getType());
868  AffineMap map = calculateImplicitMap(sequentialType, distributedType);
869  AffineMap indexMap = map.compose(read.getPermutationMap());
871  // Try to delinearize the lane ID to match the rank expected for
872  // distribution.
873  SmallVector<Value> delinearizedIds;
874  if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
875  distributedType.getShape(), warpOp.getWarpSize(),
876  warpOp.getLaneid(), delinearizedIds)) {
877  return rewriter.notifyMatchFailure(
878  read, "cannot delinearize lane ID for distribution");
879  }
880  assert(!delinearizedIds.empty() || map.getNumResults() == 0);
882  // Distribute indices and the mask (if present).
883  OpBuilder::InsertionGuard g(rewriter);
884  SmallVector<Value> additionalResults(indices.begin(), indices.end());
885  SmallVector<Type> additionalResultTypes(indices.size(),
886  rewriter.getIndexType());
887  additionalResults.push_back(read.getPadding());
888  additionalResultTypes.push_back(read.getPadding().getType());
890  bool hasMask = false;
891  if (read.getMask()) {
892  hasMask = true;
893  // TODO: Distribution of masked reads with non-trivial permutation maps
894  // requires the distribution of the mask to elementwise match the
895  // distribution of the permuted written vector. Currently the details
896  // of which lane is responsible for which element is captured strictly
897  // by shape information on the warp op, and thus requires materializing
898  // the permutation in IR.
899  if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
900  return rewriter.notifyMatchFailure(
901  read, "non-trivial permutation maps not supported");
902  VectorType maskType =
903  getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
904  additionalResults.push_back(read.getMask());
905  additionalResultTypes.push_back(maskType);
906  }
908  SmallVector<size_t> newRetIndices;
909  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
910  rewriter, warpOp, additionalResults, additionalResultTypes,
911  newRetIndices);
912  distributedVal = newWarpOp.getResult(operandIndex);
914  // Distributed indices were appended first.
915  SmallVector<Value> newIndices;
916  for (int64_t i = 0, e = indices.size(); i < e; ++i)
917  newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
919  rewriter.setInsertionPointAfter(newWarpOp);
920  for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
921  AffineExpr d0, d1;
922  bindDims(read.getContext(), d0, d1);
923  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
924  if (!indexExpr)
925  continue;
926  unsigned indexPos = indexExpr.getPosition();
927  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
928  int64_t scale = distributedType.getDimSize(vectorPos);
929  newIndices[indexPos] = affine::makeComposedAffineApply(
930  rewriter, read.getLoc(), d0 + scale * d1,
931  {newIndices[indexPos], delinearizedIds[vectorPos]});
932  }
934  // Distributed padding value was appended right after the indices.
935  Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
936  // Distributed mask value was added at the end (if the op has a mask).
937  Value newMask =
938  hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
939  : Value();
940  auto newRead = rewriter.create<vector::TransferReadOp>(
941  read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
942  read.getPermutationMapAttr(), newPadding, newMask,
943  read.getInBoundsAttr());
945  rewriter.replaceAllUsesWith(distributedVal, newRead);
946  return success();
947  }
948 };
950 /// Remove any result that has no use along with the matching yieldOp operand.
951 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
952 struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
954  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
955  PatternRewriter &rewriter) const override {
956  SmallVector<Type> newResultTypes;
957  newResultTypes.reserve(warpOp->getNumResults());
958  SmallVector<Value> newYieldValues;
959  newYieldValues.reserve(warpOp->getNumResults());
960  DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
961  DenseMap<OpResult, int64_t> dedupResultPositionMap;
962  auto yield = cast<vector::YieldOp>(
963  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
965  // Some values may be yielded multiple times and correspond to multiple
966  // results. Deduplicating occurs by taking each result with its matching
967  // yielded value, and:
968  // 1. recording the unique first position at which the value is yielded.
969  // 2. recording for the result, the first position at which the dedup'ed
970  // value is yielded.
971  // 3. skipping from the new result types / new yielded values any result
972  // that has no use or whose yielded value has already been seen.
973  for (OpResult result : warpOp.getResults()) {
974  Value yieldOperand = yield.getOperand(result.getResultNumber());
975  auto it = dedupYieldOperandPositionMap.insert(
976  std::make_pair(yieldOperand, newResultTypes.size()));
977  dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
978  if (result.use_empty() || !it.second)
979  continue;
980  newResultTypes.push_back(result.getType());
981  newYieldValues.push_back(yieldOperand);
982  }
983  // No modification, exit early.
984  if (yield.getNumOperands() == newYieldValues.size())
985  return failure();
986  // Move the body of the old warpOp to a new warpOp.
987  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
988  rewriter, warpOp, newYieldValues, newResultTypes);
990  // Simplify the new warp op after dropping dead results.
991  newWarpOp.getBody()->walk([&](Operation *op) {
992  if (isOpTriviallyDead(op))
993  rewriter.eraseOp(op);
994  });
996  // Replace results of the old warpOp by the new, deduplicated results.
997  SmallVector<Value> newValues;
998  newValues.reserve(warpOp->getNumResults());
999  for (OpResult result : warpOp.getResults()) {
1000  if (result.use_empty())
1001  newValues.push_back(Value());
1002  else
1003  newValues.push_back(
1004  newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
1005  }
1006  rewriter.replaceOp(warpOp, newValues);
1007  return success();
1008  }
1009 };
1011 // If an operand is directly yielded out of the region we can forward it
1012 // directly and it doesn't need to go through the region.
1013 struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
1015  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1016  PatternRewriter &rewriter) const override {
1017  SmallVector<Type> resultTypes;
1018  SmallVector<Value> yieldValues;
1019  auto yield = cast<vector::YieldOp>(
1020  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1021  Value valForwarded;
1022  unsigned resultIndex;
1023  for (OpOperand &operand : yield->getOpOperands()) {
1024  Value result = warpOp.getResult(operand.getOperandNumber());
1025  if (result.use_empty())
1026  continue;
1028  // Assume all the values coming from above are uniform.
1029  if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
1030  if (result.getType() != operand.get().getType())
1031  continue;
1032  valForwarded = operand.get();
1033  resultIndex = operand.getOperandNumber();
1034  break;
1035  }
1036  auto arg = dyn_cast<BlockArgument>(operand.get());
1037  if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1038  continue;
1039  Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1040  if (result.getType() != warpOperand.getType())
1041  continue;
1042  valForwarded = warpOperand;
1043  resultIndex = operand.getOperandNumber();
1044  break;
1045  }
1046  if (!valForwarded)
1047  return failure();
1048  // Notify the rewriter that the warp op is changing (see the comment on
1049  // the WarpOpTransferRead pattern).
1050  rewriter.startOpModification(warpOp);
1051  rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1052  rewriter.finalizeOpModification(warpOp);
1053  return success();
1054  }
1055 };
1057 struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1059  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1060  PatternRewriter &rewriter) const override {
1061  OpOperand *operand =
1062  getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1063  if (!operand)
1064  return failure();
1065  unsigned int operandNumber = operand->getOperandNumber();
1066  auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1067  Location loc = broadcastOp.getLoc();
1068  auto destVecType =
1069  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1070  Value broadcastSrc = broadcastOp.getSource();
1071  Type broadcastSrcType = broadcastSrc.getType();
1073  // Check that the broadcast actually spans a set of values uniformly across
1074  // all threads. In other words, check that each thread can reconstruct
1075  // their own broadcast.
1076  // For that we simply check that the broadcast we want to build makes sense.
1077  if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1079  return failure();
1080  SmallVector<size_t> newRetIndices;
1081  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1082  rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1083  rewriter.setInsertionPointAfter(newWarpOp);
1084  Value broadcasted = rewriter.create<vector::BroadcastOp>(
1085  loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1086  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1087  broadcasted);
1088  return success();
1089  }
1090 };
1092 /// Pattern to move shape cast out of the warp op. shape cast is basically a
1093 /// no-op for warp distribution; we need to handle the shape though.
1094 struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1096  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1097  PatternRewriter &rewriter) const override {
1098  OpOperand *operand =
1099  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1100  if (!operand)
1101  return failure();
1103  auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1105  unsigned int operandNumber = operand->getOperandNumber();
1106  auto castDistributedType =
1107  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1108  VectorType castOriginalType = oldCastOp.getSourceVectorType();
1109  VectorType castResultType = castDistributedType;
1111  // We expect the distributed type to have a smaller rank than the original
1112  // type. Prepend with size-one dimensions to make them the same.
1113  unsigned castDistributedRank = castDistributedType.getRank();
1114  unsigned castOriginalRank = castOriginalType.getRank();
1115  if (castDistributedRank < castOriginalRank) {
1116  SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1117  llvm::append_range(shape, castDistributedType.getShape());
1118  castDistributedType =
1119  VectorType::get(shape, castDistributedType.getElementType());
1120  }
1122  SmallVector<size_t> newRetIndices;
1123  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1124  rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1125  newRetIndices);
1126  rewriter.setInsertionPointAfter(newWarpOp);
1127  Value newCast = rewriter.create<vector::ShapeCastOp>(
1128  oldCastOp.getLoc(), castResultType,
1129  newWarpOp->getResult(newRetIndices[0]));
1130  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1131  return success();
1132  }
1133 };
1135 /// Sink out vector.create_mask op feeding into a warp op yield.
1136 /// ```
1137 /// %0 = ...
1138 /// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1139 /// ...
1140 /// %mask = vector.create_mask %0 : vector<32xi1>
1141 /// vector.yield %mask : vector<32xi1>
1142 /// }
1143 /// ```
1144 /// To
1145 /// ```
1146 /// %0 = ...
1147 /// vector.warp_execute_on_lane_0(%arg0) {
1148 /// ...
1149 /// }
1150 /// %cmp = arith.cmpi ult, %laneid, %0
1151 /// %ub = %cmp, %c0, %c1
1152 /// %1 = vector.create_mask %ub : vector<1xi1>
1153 struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
1155  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1156  PatternRewriter &rewriter) const override {
1157  OpOperand *yieldOperand =
1158  getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1159  if (!yieldOperand)
1160  return failure();
1162  auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1164  // Early exit if any values needed for calculating the new mask indices
1165  // are defined inside the warp op.
1166  if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1167  return warpOp.isDefinedOutsideOfRegion(value);
1168  }))
1169  return failure();
1171  Location loc = mask.getLoc();
1172  unsigned operandIndex = yieldOperand->getOperandNumber();
1174  auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1175  VectorType seqType = mask.getVectorType();
1176  ArrayRef<int64_t> seqShape = seqType.getShape();
1177  ArrayRef<int64_t> distShape = distType.getShape();
1179  rewriter.setInsertionPointAfter(warpOp);
1181  // Delinearize the lane ID for constructing the distributed mask sizes.
1182  SmallVector<Value> delinearizedIds;
1183  if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1184  warpOp.getWarpSize(), warpOp.getLaneid(),
1185  delinearizedIds))
1186  return rewriter.notifyMatchFailure(
1187  mask, "cannot delinearize lane ID for distribution");
1188  assert(!delinearizedIds.empty());
1190  // Notify the rewriter that the warp op is changing (see the comment on
1191  // the WarpOpTransferRead pattern).
1192  rewriter.startOpModification(warpOp);
1194  AffineExpr s0, s1;
1195  bindSymbols(rewriter.getContext(), s0, s1);
1196  SmallVector<Value> newOperands;
1197  for (int i = 0, e = distShape.size(); i < e; ++i) {
1198  // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1199  // find the distance from the largest mask index owned by this lane to the
1200  // original mask size. `vector.create_mask` implicitly clamps mask
1201  // operands to the range [0, mask_vector_size[i]], or in other words, the
1202  // mask sizes are always in the range [0, mask_vector_size[i]).
1204  rewriter, loc, s1 - s0 * distShape[i],
1205  {delinearizedIds[i], mask.getOperand(i)});
1206  newOperands.push_back(maskDimIdx);
1207  }
1209  auto newMask =
1210  rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
1211  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1212  rewriter.finalizeOpModification(warpOp);
1213  return success();
1214  }
1215 };
1217 /// Pattern to move out vector.extract of single element vector. Those don't
1218 /// need to be distributed and can just be propagated outside of the region.
1219 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
1221  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1222  PatternRewriter &rewriter) const override {
1223  OpOperand *operand =
1224  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1225  if (!operand)
1226  return failure();
1227  unsigned int operandNumber = operand->getOperandNumber();
1228  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1229  VectorType extractSrcType = extractOp.getSourceVectorType();
1230  Location loc = extractOp.getLoc();
1232  // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
1233  assert(extractSrcType.getRank() > 0 &&
1234  "vector.extract does not support rank 0 sources");
1236  // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
1237  // canonicalized to %v.
1238  if (extractOp.getNumIndices() == 0)
1239  return failure();
1241  // Rewrite vector.extract with 1d source to vector.extractelement.
1242  if (extractSrcType.getRank() == 1) {
1243  if (extractOp.hasDynamicPosition())
1244  // TODO: Dinamic position not supported yet.
1245  return failure();
1247  assert(extractOp.getNumIndices() == 1 && "expected 1 index");
1248  int64_t pos = extractOp.getStaticPosition()[0];
1249  rewriter.setInsertionPoint(extractOp);
1250  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
1251  extractOp, extractOp.getVector(),
1252  rewriter.create<arith::ConstantIndexOp>(loc, pos));
1253  return success();
1254  }
1256  // All following cases are 2d or higher dimensional source vectors.
1258  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1259  // There is no distribution, this is a broadcast. Simply move the extract
1260  // out of the warp op.
1261  // TODO: This could be optimized. E.g., in case of a scalar result, let
1262  // one lane extract and shuffle the result to all other lanes (same as
1263  // the 1d case).
1264  SmallVector<size_t> newRetIndices;
1265  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1266  rewriter, warpOp, {extractOp.getVector()},
1267  {extractOp.getSourceVectorType()}, newRetIndices);
1268  rewriter.setInsertionPointAfter(newWarpOp);
1269  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1270  // Extract from distributed vector.
1271  Value newExtract = rewriter.create<vector::ExtractOp>(
1272  loc, distributedVec, extractOp.getMixedPosition());
1273  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1274  newExtract);
1275  return success();
1276  }
1278  // Find the distributed dimension. There should be exactly one.
1279  auto distributedType =
1280  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1281  auto yieldedType = cast<VectorType>(operand->get().getType());
1282  int64_t distributedDim = -1;
1283  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1284  if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1285  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1286  // support distributing multiple dimensions in the future.
1287  assert(distributedDim == -1 && "found multiple distributed dims");
1288  distributedDim = i;
1289  }
1290  }
1291  assert(distributedDim != -1 && "could not find distributed dimension");
1292  (void)distributedDim;
1294  // Yield source vector from warp op.
1295  SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1296  for (int i = 0; i < distributedType.getRank(); ++i)
1297  newDistributedShape[i + extractOp.getNumIndices()] =
1298  distributedType.getDimSize(i);
1299  auto newDistributedType =
1300  VectorType::get(newDistributedShape, distributedType.getElementType());
1301  SmallVector<size_t> newRetIndices;
1302  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1303  rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1304  newRetIndices);
1305  rewriter.setInsertionPointAfter(newWarpOp);
1306  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1307  // Extract from distributed vector.
1308  Value newExtract = rewriter.create<vector::ExtractOp>(
1309  loc, distributedVec, extractOp.getMixedPosition());
1310  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1311  newExtract);
1312  return success();
1313  }
1314 };
1316 /// Pattern to move out vector.extractelement of 0-D tensors. Those don't
1317 /// need to be distributed and can just be propagated outside of the region.
1318 struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1319  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1320  PatternBenefit b = 1)
1321  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1322  warpShuffleFromIdxFn(std::move(fn)) {}
1323  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1324  PatternRewriter &rewriter) const override {
1325  OpOperand *operand =
1326  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1327  if (!operand)
1328  return failure();
1329  unsigned int operandNumber = operand->getOperandNumber();
1330  auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1331  VectorType extractSrcType = extractOp.getSourceVectorType();
1332  // TODO: Supported shuffle types should be parameterizable, similar to
1333  // `WarpShuffleFromIdxFn`.
1334  if (!extractSrcType.getElementType().isF32() &&
1335  !extractSrcType.getElementType().isInteger(32))
1336  return rewriter.notifyMatchFailure(
1337  extractOp, "only f32/i32 element types are supported");
1338  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1339  Type elType = extractSrcType.getElementType();
1340  VectorType distributedVecType;
1341  if (!is0dOrVec1Extract) {
1342  assert(extractSrcType.getRank() == 1 &&
1343  "expected that extractelement src rank is 0 or 1");
1344  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1345  return failure();
1346  int64_t elementsPerLane =
1347  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1348  distributedVecType = VectorType::get({elementsPerLane}, elType);
1349  } else {
1350  distributedVecType = extractSrcType;
1351  }
1352  // Yield source vector and position (if present) from warp op.
1353  SmallVector<Value> additionalResults{extractOp.getVector()};
1354  SmallVector<Type> additionalResultTypes{distributedVecType};
1355  if (static_cast<bool>(extractOp.getPosition())) {
1356  additionalResults.push_back(extractOp.getPosition());
1357  additionalResultTypes.push_back(extractOp.getPosition().getType());
1358  }
1359  Location loc = extractOp.getLoc();
1360  SmallVector<size_t> newRetIndices;
1361  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1362  rewriter, warpOp, additionalResults, additionalResultTypes,
1363  newRetIndices);
1364  rewriter.setInsertionPointAfter(newWarpOp);
1365  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1367  // 0d extract: The new warp op broadcasts the source vector to all lanes.
1368  // All lanes extract the scalar.
1369  if (is0dOrVec1Extract) {
1370  Value newExtract;
1371  if (extractSrcType.getRank() == 1) {
1372  newExtract = rewriter.create<vector::ExtractElementOp>(
1373  loc, distributedVec,
1374  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1376  } else {
1377  newExtract =
1378  rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
1379  }
1380  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1381  newExtract);
1382  return success();
1383  }
1385  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1386  // the value to all other lanes.
1387  int64_t elementsPerLane = distributedVecType.getShape()[0];
1388  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1389  // tid of extracting thread: pos / elementsPerLane
1390  Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
1391  loc, sym0.ceilDiv(elementsPerLane),
1392  newWarpOp->getResult(newRetIndices[1]));
1393  // Extract at position: pos % elementsPerLane
1394  Value pos =
1395  elementsPerLane == 1
1396  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1397  : rewriter
1398  .create<affine::AffineApplyOp>(
1399  loc, sym0 % elementsPerLane,
1400  newWarpOp->getResult(newRetIndices[1]))
1401  .getResult();
1402  Value extracted =
1403  rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
1405  // Shuffle the extracted value to all lanes.
1406  Value shuffled = warpShuffleFromIdxFn(
1407  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1408  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1409  return success();
1410  }
1412 private:
1413  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1414 };
1416 struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1419  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1420  PatternRewriter &rewriter) const override {
1421  OpOperand *operand =
1422  getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1423  if (!operand)
1424  return failure();
1425  unsigned int operandNumber = operand->getOperandNumber();
1426  auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1427  VectorType vecType = insertOp.getDestVectorType();
1428  VectorType distrType =
1429  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1430  bool hasPos = static_cast<bool>(insertOp.getPosition());
1432  // Yield destination vector, source scalar and position from warp op.
1433  SmallVector<Value> additionalResults{insertOp.getDest(),
1434  insertOp.getSource()};
1435  SmallVector<Type> additionalResultTypes{distrType,
1436  insertOp.getSource().getType()};
1437  if (hasPos) {
1438  additionalResults.push_back(insertOp.getPosition());
1439  additionalResultTypes.push_back(insertOp.getPosition().getType());
1440  }
1441  Location loc = insertOp.getLoc();
1442  SmallVector<size_t> newRetIndices;
1443  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1444  rewriter, warpOp, additionalResults, additionalResultTypes,
1445  newRetIndices);
1446  rewriter.setInsertionPointAfter(newWarpOp);
1447  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1448  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1449  Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
1450  rewriter.setInsertionPointAfter(newWarpOp);
1452  if (vecType == distrType) {
1453  // Broadcast: Simply move the vector.inserelement op out.
1454  Value newInsert = rewriter.create<vector::InsertElementOp>(
1455  loc, newSource, distributedVec, newPos);
1456  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1457  newInsert);
1458  return success();
1459  }
1461  // This is a distribution. Only one lane should insert.
1462  int64_t elementsPerLane = distrType.getShape()[0];
1463  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1464  // tid of extracting thread: pos / elementsPerLane
1465  Value insertingLane = rewriter.create<affine::AffineApplyOp>(
1466  loc, sym0.ceilDiv(elementsPerLane), newPos);
1467  // Insert position: pos % elementsPerLane
1468  Value pos =
1469  elementsPerLane == 1
1470  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1471  : rewriter
1472  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1473  newPos)
1474  .getResult();
1475  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1476  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1477  Value newResult =
1478  rewriter
1479  .create<scf::IfOp>(
1480  loc, isInsertingLane,
1481  /*thenBuilder=*/
1482  [&](OpBuilder &builder, Location loc) {
1483  Value newInsert = builder.create<vector::InsertElementOp>(
1484  loc, newSource, distributedVec, pos);
1485  builder.create<scf::YieldOp>(loc, newInsert);
1486  },
1487  /*elseBuilder=*/
1488  [&](OpBuilder &builder, Location loc) {
1489  builder.create<scf::YieldOp>(loc, distributedVec);
1490  })
1491  .getResult(0);
1492  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1493  return success();
1494  }
1495 };
1497 struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
1500  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1501  PatternRewriter &rewriter) const override {
1502  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1503  if (!operand)
1504  return failure();
1505  unsigned int operandNumber = operand->getOperandNumber();
1506  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1507  Location loc = insertOp.getLoc();
1509  // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
1510  if (insertOp.getNumIndices() == 0)
1511  return failure();
1513  // Rewrite vector.insert with 1d dest to vector.insertelement.
1514  if (insertOp.getDestVectorType().getRank() == 1) {
1515  if (insertOp.hasDynamicPosition())
1516  // TODO: Dinamic position not supported yet.
1517  return failure();
1519  assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1520  int64_t pos = insertOp.getStaticPosition()[0];
1521  rewriter.setInsertionPoint(insertOp);
1522  rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
1523  insertOp, insertOp.getSource(), insertOp.getDest(),
1524  rewriter.create<arith::ConstantIndexOp>(loc, pos));
1525  return success();
1526  }
1528  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1529  // There is no distribution, this is a broadcast. Simply move the insert
1530  // out of the warp op.
1531  SmallVector<size_t> newRetIndices;
1532  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1533  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1534  {insertOp.getSourceType(), insertOp.getDestVectorType()},
1535  newRetIndices);
1536  rewriter.setInsertionPointAfter(newWarpOp);
1537  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1538  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1539  Value newResult = rewriter.create<vector::InsertOp>(
1540  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1541  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1542  newResult);
1543  return success();
1544  }
1546  // Find the distributed dimension. There should be exactly one.
1547  auto distrDestType =
1548  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1549  auto yieldedType = cast<VectorType>(operand->get().getType());
1550  int64_t distrDestDim = -1;
1551  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1552  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1553  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1554  // support distributing multiple dimensions in the future.
1555  assert(distrDestDim == -1 && "found multiple distributed dims");
1556  distrDestDim = i;
1557  }
1558  }
1559  assert(distrDestDim != -1 && "could not find distributed dimension");
1561  // Compute the distributed source vector type.
1562  VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1563  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1564  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1565  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1566  // insert a smaller vector<3xf32>.
1567  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1568  // case, one lane will insert the source vector<96xf32>. The other
1569  // lanes will not do anything.
1570  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1571  if (distrSrcDim >= 0)
1572  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1573  auto distrSrcType =
1574  VectorType::get(distrSrcShape, distrDestType.getElementType());
1576  // Yield source and dest vectors from warp op.
1577  SmallVector<size_t> newRetIndices;
1578  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1579  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1580  {distrSrcType, distrDestType}, newRetIndices);
1581  rewriter.setInsertionPointAfter(newWarpOp);
1582  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1583  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1585  // Insert into the distributed vector.
1586  Value newResult;
1587  if (distrSrcDim >= 0) {
1588  // Every lane inserts a small piece.
1589  newResult = rewriter.create<vector::InsertOp>(
1590  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1591  } else {
1592  // One lane inserts the entire source vector.
1593  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1594  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1595  SmallVector<int64_t> newPos = getAsIntegers(pos);
1596  // tid of inserting lane: pos / elementsPerLane
1597  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1598  loc, newPos[distrDestDim] / elementsPerLane);
1599  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1600  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1601  // Insert position: pos % elementsPerLane
1602  newPos[distrDestDim] %= elementsPerLane;
1603  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1604  Value newInsert = builder.create<vector::InsertOp>(
1605  loc, distributedSrc, distributedDest, newPos);
1606  builder.create<scf::YieldOp>(loc, newInsert);
1607  };
1608  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1609  builder.create<scf::YieldOp>(loc, distributedDest);
1610  };
1611  newResult = rewriter
1612  .create<scf::IfOp>(loc, isInsertingLane,
1613  /*thenBuilder=*/insertingBuilder,
1614  /*elseBuilder=*/nonInsertingBuilder)
1615  .getResult(0);
1616  }
1618  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1619  return success();
1620  }
1621 };
1623 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1624 /// the scf.ForOp is the last operation in the region so that it doesn't change
1625 /// the order of execution. This creates a new scf.for region after the
1626 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1627 /// WarpExecuteOnLane0Op region. Example:
1628 /// ```
1629 /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1630 /// ...
1631 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1632 /// -> (vector<128xf32>) {
1633 /// ...
1634 /// scf.yield %r : vector<128xf32>
1635 /// }
1636 /// vector.yield %v1 : vector<128xf32>
1637 /// }
1638 /// ```
1639 /// To:
1640 /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1641 /// ...
1642 /// vector.yield %v : vector<128xf32>
1643 /// }
1644 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
1645 /// -> (vector<4xf32>) {
1646 /// %iw = vector.warp_execute_on_lane_0(%laneid)
1647 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1648 /// ^bb0(%arg: vector<128xf32>):
1649 /// ...
1650 /// vector.yield %ir : vector<128xf32>
1651 /// }
1652 /// scf.yield %iw : vector<4xf32>
1653 /// }
1654 /// ```
1655 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1657  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1658  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1659  distributionMapFn(std::move(fn)) {}
1661  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1662  PatternRewriter &rewriter) const override {
1663  auto yield = cast<vector::YieldOp>(
1664  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1665  // Only pick up forOp if it is the last op in the region.
1666  Operation *lastNode = yield->getPrevNode();
1667  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1668  if (!forOp)
1669  return failure();
1670  // Collect Values that come from the warp op but are outside the forOp.
1671  // Those Value needs to be returned by the original warpOp and passed to the
1672  // new op.
1673  llvm::SmallSetVector<Value, 32> escapingValues;
1674  SmallVector<Type> inputTypes;
1675  SmallVector<Type> distTypes;
1677  forOp.getBodyRegion(), [&](OpOperand *operand) {
1678  Operation *parent = operand->get().getParentRegion()->getParentOp();
1679  if (warpOp->isAncestor(parent)) {
1680  if (!escapingValues.insert(operand->get()))
1681  return;
1682  Type distType = operand->get().getType();
1683  if (auto vecType = dyn_cast<VectorType>(distType)) {
1684  AffineMap map = distributionMapFn(operand->get());
1685  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1686  }
1687  inputTypes.push_back(operand->get().getType());
1688  distTypes.push_back(distType);
1689  }
1690  });
1692  if (llvm::is_contained(distTypes, Type{}))
1693  return failure();
1695  SmallVector<size_t> newRetIndices;
1696  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1697  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1698  newRetIndices);
1699  yield = cast<vector::YieldOp>(
1700  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1702  SmallVector<Value> newOperands;
1703  SmallVector<unsigned> resultIdx;
1704  // Collect all the outputs coming from the forOp.
1705  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1706  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1707  continue;
1708  auto forResult = cast<OpResult>(yieldOperand.get());
1709  newOperands.push_back(
1710  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1711  yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1712  resultIdx.push_back(yieldOperand.getOperandNumber());
1713  }
1715  OpBuilder::InsertionGuard g(rewriter);
1716  rewriter.setInsertionPointAfter(newWarpOp);
1718  // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719  // inside.
1720  auto newForOp = rewriter.create<scf::ForOp>(
1721  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1722  forOp.getStep(), newOperands);
1723  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
1725  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1726  newForOp.getRegionIterArgs().end());
1727  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1728  forOp.getResultTypes().end());
1729  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1730  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1731  warpInput.push_back(newWarpOp.getResult(retIdx));
1732  argIndexMapping[escapingValues[i]] = warpInputType.size();
1733  warpInputType.push_back(inputTypes[i]);
1734  }
1735  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1736  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1737  newWarpOp.getWarpSize(), warpInput, warpInputType);
1739  SmallVector<Value> argMapping;
1740  argMapping.push_back(newForOp.getInductionVar());
1741  for (Value args : innerWarp.getBody()->getArguments()) {
1742  argMapping.push_back(args);
1743  }
1744  argMapping.resize(forOp.getBody()->getNumArguments());
1745  SmallVector<Value> yieldOperands;
1746  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1747  yieldOperands.push_back(operand);
1748  rewriter.eraseOp(forOp.getBody()->getTerminator());
1749  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1750  rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
1751  rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1752  rewriter.setInsertionPointAfter(innerWarp);
1753  if (!innerWarp.getResults().empty())
1754  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1755  rewriter.eraseOp(forOp);
1756  // Replace the warpOp result coming from the original ForOp.
1757  for (const auto &res : llvm::enumerate(resultIdx)) {
1758  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1759  newForOp.getResult(res.index()));
1760  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1761  }
1762  newForOp.walk([&](Operation *op) {
1763  for (OpOperand &operand : op->getOpOperands()) {
1764  auto it = argIndexMapping.find(operand.get());
1765  if (it == argIndexMapping.end())
1766  continue;
1767  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1768  }
1769  });
1771  // Finally, hoist out any now uniform code from the inner warp op.
1772  mlir::vector::moveScalarUniformCode(innerWarp);
1773  return success();
1774  }
1776 private:
1777  DistributionMapFn distributionMapFn;
1778 };
1780 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781 /// The vector is reduced in parallel. Currently limited to vector size matching
1782 /// the warpOp size. E.g.:
1783 /// ```
1784 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1785 /// %0 = "some_def"() : () -> (vector<32xf32>)
1786 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1787 /// vector_ext.yield %1 : f32
1788 /// }
1789 /// ```
1790 /// is lowered to:
1791 /// ```
1792 /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1793 /// %1 = "some_def"() : () -> (vector<32xf32>)
1794 /// vector_ext.yield %1 : vector<32xf32>
1795 /// }
1796 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1797 /// %r = ("warp.reduction %a")
1798 /// ```
1799 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
1800  WarpOpReduction(MLIRContext *context,
1801  DistributedReductionFn distributedReductionFn,
1802  PatternBenefit benefit = 1)
1803  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
1804  distributedReductionFn(std::move(distributedReductionFn)) {}
1806  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1807  PatternRewriter &rewriter) const override {
1808  OpOperand *yieldOperand =
1809  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1810  if (!yieldOperand)
1811  return failure();
1813  auto reductionOp =
1814  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1815  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1816  // Only rank 1 vectors supported.
1817  if (vectorType.getRank() != 1)
1818  return rewriter.notifyMatchFailure(
1819  warpOp, "Only rank 1 reductions can be distributed.");
1820  // Only warp_size-sized vectors supported.
1821  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1822  return rewriter.notifyMatchFailure(
1823  warpOp, "Reduction vector dimension must match was size.");
1824  if (!reductionOp.getType().isIntOrFloat())
1825  return rewriter.notifyMatchFailure(
1826  warpOp, "Reduction distribution currently only supports floats and "
1827  "integer types.");
1829  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1830  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1831  unsigned operandIndex = yieldOperand->getOperandNumber();
1832  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1833  SmallVector<Type> retTypes = {
1834  VectorType::get({numElements}, reductionOp.getType())};
1835  if (reductionOp.getAcc()) {
1836  yieldValues.push_back(reductionOp.getAcc());
1837  retTypes.push_back(reductionOp.getAcc().getType());
1838  }
1839  SmallVector<size_t> newRetIndices;
1840  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1841  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1842  rewriter.setInsertionPointAfter(newWarpOp);
1844  // Obtain data to reduce for a single lane.
1845  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1846  // Distribute and reduce across threads.
1847  Value fullReduce =
1848  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1849  reductionOp.getKind(), newWarpOp.getWarpSize());
1850  if (reductionOp.getAcc()) {
1851  fullReduce = vector::makeArithReduction(
1852  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1853  newWarpOp.getResult(newRetIndices[1]));
1854  }
1855  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1856  return success();
1857  }
1859 private:
1860  DistributedReductionFn distributedReductionFn;
1861 };
1863 } // namespace
1866  RewritePatternSet &patterns,
1868  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1869 }
1871 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1872  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1873  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
1874  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1875  maxNumElementsToExtract, benefit);
1876 }
1878 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1879  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1880  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1881  PatternBenefit readBenefit) {
1882  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1883  patterns
1884  .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1885  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1886  WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1887  patterns.getContext(), benefit);
1888  patterns.add<WarpOpExtractElement>(patterns.getContext(),
1889  warpShuffleFromIdxFn, benefit);
1890  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1891  benefit);
1892 }
1894 void mlir::vector::populateDistributeReduction(
1895  RewritePatternSet &patterns,
1896  const DistributedReductionFn &distributedReductionFn,
1897  PatternBenefit benefit) {
1898  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1899  benefit);
1900 }
1902 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1903  Block *body = warpOp.getBody();
1905  // Keep track of the ops we want to hoist.
1906  llvm::SmallSetVector<Operation *, 8> opsToMove;
1908  // Helper to check if a value is or will be defined outside of the region.
1909  auto isDefinedOutsideOfBody = [&](Value value) {
1910  auto *definingOp = value.getDefiningOp();
1911  return (definingOp && opsToMove.count(definingOp)) ||
1912  warpOp.isDefinedOutsideOfRegion(value);
1913  };
1915  // Do not use walk here, as we do not want to go into nested regions and hoist
1916  // operations from there.
1917  for (auto &op : body->without_terminator()) {
1918  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1919  return isa<VectorType>(result.getType());
1920  });
1921  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1922  opsToMove.insert(&op);
1923  }
1925  // Move all the ops marked as uniform outside of the region.
1926  for (Operation *op : opsToMove)
1927  op->moveBefore(warpOp);
1928 }
