VectorDistribute.cpp
1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
12 #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/Vector/IR/VectorOps.h"
16 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
17 #include "mlir/IR/AffineExpr.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/Interfaces/SideEffectInterfaces.h"
21 #include "mlir/Transforms/RegionUtils.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallVectorExtras.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <utility>
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 using namespace mlir::gpu;
30 
31 /// Currently the distribution map is implicit based on the vector shape. In the
32 /// future it will be part of the op.
33 /// Example:
34 /// ```
35 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
36 /// ...
37 /// gpu.yield %3 : vector<32x16x64xf32>
38 /// }
39 /// ```
40 /// Would have an implicit map of:
41 /// `(d0, d1, d2) -> (d0, d2)`
42 static AffineMap calculateImplicitMap(VectorType sequentialType,
43  VectorType distributedType) {
44   SmallVector<AffineExpr> perm;
45   perm.reserve(1);
46  // Check which dimensions of the sequential type are different than the
47  // dimensions of the distributed type to know the distributed dimensions. Then
48  // associate each distributed dimension to an ID in order.
49  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
50  if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
51  perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
52  }
53  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
54  distributedType.getContext());
55  return map;
56 }
57 
58 /// Given a sequential and distributed vector type, returns the distributed
59 /// dimension. This function expects that only a single dimension is
60 /// distributed.
61 static int getDistributedDim(VectorType sequentialType,
62  VectorType distributedType) {
63  assert(sequentialType.getRank() == distributedType.getRank() &&
64  "sequential and distributed vector types must have the same rank");
65  int64_t distributedDim = -1;
66  for (int64_t i = 0; i < sequentialType.getRank(); ++i) {
67  if (distributedType.getDimSize(i) != sequentialType.getDimSize(i)) {
68  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
69  // support distributing multiple dimensions in the future.
70  assert(distributedDim == -1 && "found multiple distributed dims");
71  distributedDim = i;
72  }
73  }
74  return distributedDim;
75 }
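// For example (hypothetical types, a minimal sketch of how the two helpers
// above behave):
//   auto seq  = VectorType::get({32, 16, 64}, f32);   // sequential type
//   auto dist = VectorType::get({1, 16, 2}, f32);     // distributed type
//   // calculateImplicitMap(seq, dist) == (d0, d1, d2) -> (d0, d2)
//   // getDistributedDim expects exactly one differing dimension, e.g.
//   // getDistributedDim(VectorType::get({32, 16}, f32),
//   //                   VectorType::get({1, 16}, f32)) == 0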
76 
77 namespace {
78 
79 /// Helper struct to create the load / store operations that permit transit
80 /// through the parallel / sequential and the sequential / parallel boundaries
81 /// when performing `rewriteWarpOpToScfFor`.
82 ///
83 /// The vector distribution dimension is inferred from the vector types.
84 struct DistributedLoadStoreHelper {
85  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
86  Value laneId, Value zero)
87  : sequentialVal(sequentialVal), distributedVal(distributedVal),
88  laneId(laneId), zero(zero) {
89  sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
90  distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
91  if (sequentialVectorType && distributedVectorType)
92  distributionMap =
93  calculateImplicitMap(sequentialVectorType, distributedVectorType);
94  }
95 
96  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
97  int64_t distributedSize = distributedVectorType.getDimSize(index);
98    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
99    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
100  ArrayRef<Value>{laneId});
101  }
102 
103  /// Create a store during the process of distributing the
104  /// `gpu.warp_execute_on_lane_0` op.
105  /// Vector distribution assumes the following convention regarding the
106  /// temporary buffers that are created to transition values. This **must**
107  /// be properly specified in the `options.warpAllocationFn`:
108  /// 1. scalars of type T transit through a memref<1xT>.
109  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
110  Operation *buildStore(RewriterBase &b, Location loc, Value val,
111  Value buffer) {
112  assert((val == distributedVal || val == sequentialVal) &&
113  "Must store either the preregistered distributed or the "
114  "preregistered sequential value.");
115  // Scalar case can directly use memref.store.
116  if (!isa<VectorType>(val.getType()))
117  return memref::StoreOp::create(b, loc, val, buffer, zero);
118 
119  // Vector case must use vector::TransferWriteOp which will later lower to
120  // vector.store or memref.store depending on further lowerings.
121  int64_t rank = sequentialVectorType.getRank();
122  SmallVector<Value> indices(rank, zero);
123  if (val == distributedVal) {
124  for (auto dimExpr : distributionMap.getResults()) {
125  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
126  indices[index] = buildDistributedOffset(b, loc, index);
127  }
128  }
129  SmallVector<bool> inBounds(indices.size(), true);
130  return vector::TransferWriteOp::create(
131  b, loc, val, buffer, indices,
132  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
133  }
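 // A minimal sketch of this convention (hypothetical IR; %buf*, %laneid and
 // %c0 are placeholders), assuming a warp of 32 and a distributed size of 1:
 // ```
 // // 1. scalar %s : f32 transits through a memref<1xf32>:
 // memref.store %s, %buf0[%c0] : memref<1xf32>
 // // 2. distributed %v : vector<1xf32> is written at offset laneid * 1 into
 // //    the sequential buffer:
 // vector.transfer_write %v, %buf1[%laneid] {in_bounds = [true]}
 //     : vector<1xf32>, memref<32xf32>
 // ```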
134 
135  /// Create a load during the process of distributing the
136  /// `gpu.warp_execute_on_lane_0` op.
137  /// Vector distribution assumes the following convention regarding the
138  /// temporary buffers that are created to transition values. This **must**
139  /// be properly specified in the `options.warpAllocationFn`:
140  /// 1. scalars of type T transit through a memref<1xT>.
141  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
142  ///
143  /// When the requested `type` is the sequential type, the load is not
144  /// distributed, to match the broadcast semantics of `gpu.warp_execute_on_lane_0`.
145  ///
146  /// Example:
147  ///
148  /// ```
149  /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
150  /// gpu.yield %cst : f32
151  /// }
152  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
153  /// ```
154  /// This behavior is described in more detail in the documentation of the op.
155  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
156 
157  // Scalar case can directly use memref.load.
158  if (!isa<VectorType>(type))
159  return memref::LoadOp::create(b, loc, buffer, zero);
160 
161  // All other cases must be vectors at the moment.
162  // Vector case must use vector::TransferReadOp which will later lower to
163  // vector.load or memref.load depending on further lowerings.
164  assert((type == distributedVectorType || type == sequentialVectorType) &&
165         "Must load either the preregistered distributed or the "
166         "preregistered sequential type.");
167  SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
168  if (type == distributedVectorType) {
169  for (auto dimExpr : distributionMap.getResults()) {
170  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
171  indices[index] = buildDistributedOffset(b, loc, index);
172  }
173  }
174  SmallVector<bool> inBounds(indices.size(), true);
175  return vector::TransferReadOp::create(
176  b, loc, cast<VectorType>(type), buffer, indices,
177  /*padding=*/std::nullopt,
178  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
179  }
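 // The load side mirrors the store (hypothetical IR): requesting the
 // distributed type reads at the per-lane offset, while requesting the
 // sequential type reads the whole buffer, so every lane sees the same
 // value (the broadcast case described above):
 // ```
 // %dist = vector.transfer_read %buf[%laneid], %pad : memref<32xf32>, vector<1xf32>
 // %seq  = vector.transfer_read %buf[%c0], %pad : memref<32xf32>, vector<32xf32>
 // ```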
180 
181  Value sequentialVal, distributedVal, laneId, zero;
182  VectorType sequentialVectorType, distributedVectorType;
183  AffineMap distributionMap;
184 };
185 
186 } // namespace
187 
188 // Clones `op` into a new operation that takes `operands` and returns
189 // `resultTypes`.
190 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
191                                               Location loc, Operation *op,
192  ArrayRef<Value> operands,
193  ArrayRef<Type> resultTypes) {
194  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
195  op->getAttrs());
196  return rewriter.create(res);
197 }
198 
199 namespace {
200 
201 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
202 /// thread `laneId` executes the entirety of the computation.
203 ///
204 /// After the transformation:
205 /// - the IR within the scf.if op can be thought of as executing sequentially
206 /// (from the point of view of threads along `laneId`).
207 /// - the IR outside of the scf.if op can be thought of as executing in
208 /// parallel (from the point of view of threads along `laneId`).
209 ///
210 /// Values that need to transit through the parallel / sequential and the
211 /// sequential / parallel boundaries do so via reads and writes to a temporary
212 /// memory location.
213 ///
214 /// The transformation proceeds in multiple steps:
215 /// 1. Create the scf.if op.
216 /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
217 /// within the scf.if to transit the values captured from above.
218 /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
219 /// consistent within the scf.if.
220 /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
221 /// 5. Insert appropriate writes within scf.if and reads after the scf.if to
222 /// transit the values returned by the op.
223 /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
224 /// consistent after the scf.if.
225 /// 7. Perform late cleanups.
226 ///
227 /// All this assumes the vector distribution occurs along the most minor
228 /// distributed vector dimension.
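///
/// For illustration, a schematic sketch (not verbatim output) for a warp op
/// with a single vector result and a warp size of 32:
/// ```
/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %v = "compute"() : () -> (vector<32xf32>)
///   gpu.yield %v : vector<32xf32>
/// }
/// // becomes, roughly:
/// %buf = ...                       // from options.warpAllocationFn
/// %c0 = arith.constant 0 : index
/// %is_lane0 = arith.cmpi eq, %laneid, %c0 : index
/// scf.if %is_lane0 {
///   %v = "compute"() : () -> (vector<32xf32>)
///   vector.transfer_write %v, %buf[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// // barrier inserted by options.warpSyncronizationFn
/// %r = vector.transfer_read %buf[%laneid], %pad : memref<32xf32>, vector<1xf32>
/// ```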
229 struct WarpOpToScfIfPattern : public WarpDistributionPattern {
230  WarpOpToScfIfPattern(MLIRContext *context,
231                       const WarpExecuteOnLane0LoweringOptions &options,
232                       PatternBenefit benefit = 1)
233  : WarpDistributionPattern(context, benefit), options(options) {}
234 
235  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
236  PatternRewriter &rewriter) const override {
237  assert(warpOp.getBodyRegion().hasOneBlock() &&
238  "expected WarpOp with single block");
239  Block *warpOpBody = &warpOp.getBodyRegion().front();
240  Location loc = warpOp.getLoc();
241 
242  // Passed all checks. Start rewriting.
243  OpBuilder::InsertionGuard g(rewriter);
244  rewriter.setInsertionPoint(warpOp);
245 
246  // Step 1: Create scf.if op.
247  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
248  Value isLane0 = arith::CmpIOp::create(
249  rewriter, loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
250  auto ifOp = scf::IfOp::create(rewriter, loc, isLane0,
251  /*withElseRegion=*/false);
252  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
253 
254  // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
255  // reads within the scf.if to transit the values captured from above.
256  SmallVector<Value> bbArgReplacements;
257  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
258  Value sequentialVal = warpOpBody->getArgument(it.index());
259  Value distributedVal = it.value();
260  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
261  warpOp.getLaneid(), c0);
262 
263  // Create buffer before the ifOp.
264  rewriter.setInsertionPoint(ifOp);
265  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
266  sequentialVal.getType());
267  // Store distributed vector into buffer, before the ifOp.
268  helper.buildStore(rewriter, loc, distributedVal, buffer);
269  // Load sequential vector from buffer, inside the ifOp.
270  rewriter.setInsertionPointToStart(ifOp.thenBlock());
271  bbArgReplacements.push_back(
272  helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
273  }
274 
275  // Step 3. Insert sync after all the stores and before all the loads.
276  if (!warpOp.getArgs().empty()) {
277  rewriter.setInsertionPoint(ifOp);
278  options.warpSyncronizationFn(loc, rewriter, warpOp);
279  }
280 
281  // Step 4. Move body of warpOp to ifOp.
282  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
283 
284  // Step 5. Insert appropriate writes within scf.if and reads after the
285  // scf.if to transit the values returned by the op.
286  // TODO: at this point, we can reuse the shared memory from previous
287  // buffers.
288  SmallVector<Value> replacements;
289  auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator());
290  Location yieldLoc = yieldOp.getLoc();
291  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
292  Value sequentialVal = it.value();
293  Value distributedVal = warpOp->getResult(it.index());
294  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
295  warpOp.getLaneid(), c0);
296 
297  // Create buffer before the ifOp.
298  rewriter.setInsertionPoint(ifOp);
299  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
300  sequentialVal.getType());
301 
302  // Store yielded value into buffer, inside the ifOp, before the
303  // terminator.
304  rewriter.setInsertionPoint(yieldOp);
305  helper.buildStore(rewriter, loc, sequentialVal, buffer);
306 
307  // Load distributed value from buffer, after the warpOp.
308  rewriter.setInsertionPointAfter(ifOp);
309  // Result type and yielded value type are the same. This is a broadcast.
310  // E.g.:
311  // %r = gpu.warp_execute_on_lane_0(...) -> (f32) {
312  // gpu.yield %cst : f32
313  // }
314  // Both types are f32. The constant %cst is broadcasted to all lanes.
315  // This is described in more detail in the documentation of the op.
316  replacements.push_back(
317  helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
318  }
319 
320  // Step 6. Insert sync after all the stores and before all the loads.
321  if (!yieldOp.getOperands().empty()) {
322  rewriter.setInsertionPointAfter(ifOp);
323  options.warpSyncronizationFn(loc, rewriter, warpOp);
324  }
325 
326  // Step 7. Delete terminator and add empty scf.yield.
327  rewriter.eraseOp(yieldOp);
328  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
329  scf::YieldOp::create(rewriter, yieldLoc);
330 
331  // Compute replacements for WarpOp results.
332  rewriter.replaceOp(warpOp, replacements);
333 
334  return success();
335  }
336 
337 private:
338   const WarpExecuteOnLane0LoweringOptions &options;
339 };
340 
341 /// Return the distributed vector type based on the original type and the
342 /// distribution map. The map is expected to have a dimension equal to the
343 /// original type rank and should be a projection where the results are the
344 /// distributed dimensions. The number of results should be equal to the number
345 /// of warp sizes which is currently limited to 1.
346 /// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
347 /// and a warp size of 16 has its second dimension (associated with d1)
348 /// distributed, returning vector<16x2x64>.
349 static VectorType getDistributedType(VectorType originalType, AffineMap map,
350  int64_t warpSize) {
351  SmallVector<int64_t> targetShape(originalType.getShape());
352  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
353  unsigned position = map.getDimPosition(i);
354  if (targetShape[position] % warpSize != 0) {
355  if (warpSize % targetShape[position] != 0) {
356  return VectorType();
357  }
358  warpSize /= targetShape[position];
359  targetShape[position] = 1;
360  continue;
361  }
362  targetShape[position] = targetShape[position] / warpSize;
363  warpSize = 1;
364  break;
365  }
366  if (warpSize != 1) {
367  return VectorType();
368  }
369  VectorType targetType =
370  VectorType::get(targetShape, originalType.getElementType());
371  return targetType;
372 }
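// Worked examples (hypothetical shapes) of the loop above, with a warp size
// of 32 and the single-result map (d0, d1) -> (d1):
//   - vector<16x64xf32>: 64 % 32 == 0, so d1 becomes 64 / 32 = 2 and the
//     result is vector<16x2xf32>.
//   - vector<16x8xf32>: 8 % 32 != 0 but 32 % 8 == 0, so d1 becomes 1 and a
//     warp factor of 4 is left over; since the warp size never reaches 1,
//     the function returns a null VectorType to signal failure.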
373 
374 /// Given a warpOp that contains ops with regions, the corresponding op's
375 /// "inner" region and the distributionMapFn, get all values used by the op's
376 /// region that are defined within the warpOp, but outside the inner region.
377 /// Return the set of values, their types and their distributed types.
378 std::tuple<llvm::SmallSetVector<Value, 32>, SmallVector<Type>,
379            SmallVector<Type>>
380 getInnerRegionEscapingValues(WarpExecuteOnLane0Op warpOp, Region &innerRegion,
381  DistributionMapFn distributionMapFn) {
382  llvm::SmallSetVector<Value, 32> escapingValues;
383  SmallVector<Type> escapingValueTypes;
384  SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
385  if (innerRegion.empty())
386  return {std::move(escapingValues), std::move(escapingValueTypes),
387  std::move(escapingValueDistTypes)};
388  mlir::visitUsedValuesDefinedAbove(innerRegion, [&](OpOperand *operand) {
389  Operation *parent = operand->get().getParentRegion()->getParentOp();
390  if (warpOp->isAncestor(parent)) {
391  if (!escapingValues.insert(operand->get()))
392  return;
393  Type distType = operand->get().getType();
394  if (auto vecType = dyn_cast<VectorType>(distType)) {
395  AffineMap map = distributionMapFn(operand->get());
396  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
397  }
398  escapingValueTypes.push_back(operand->get().getType());
399  escapingValueDistTypes.push_back(distType);
400  }
401  });
402  return {std::move(escapingValues), std::move(escapingValueTypes),
403  std::move(escapingValueDistTypes)};
404 }
405 
406 /// Distribute transfer_write ops based on the affine map returned by
407 /// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
408 /// will not be distributed (it should be less than the warp size).
409 ///
410 /// Example:
411 /// ```
412 /// %0 = gpu.warp_execute_on_lane_0(%id){
413 /// ...
414 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
415 /// gpu.yield
416 /// }
417 /// ```
418 /// To
419 /// ```
420 /// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
421 /// ...
422 /// gpu.yield %v : vector<32xf32>
423 /// }
424 /// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
425 struct WarpOpTransferWrite : public WarpDistributionPattern {
426  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
427  unsigned maxNumElementsToExtract, PatternBenefit b = 1)
428  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)),
429  maxNumElementsToExtract(maxNumElementsToExtract) {}
430 
431  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
432  /// are multiples of the distribution ratio are supported at the moment.
433  LogicalResult tryDistributeOp(RewriterBase &rewriter,
434  vector::TransferWriteOp writeOp,
435  WarpExecuteOnLane0Op warpOp) const {
436  VectorType writtenVectorType = writeOp.getVectorType();
437 
438    // 1. If the write is 0-D, bail out of distribution here; tryExtractOp will
439    // then clone it into a separate WarpExecuteOnLane0Op.
440  if (writtenVectorType.getRank() == 0)
441  return failure();
442 
443  // 2. Compute the distributed type.
444  AffineMap map = distributionMapFn(writeOp.getVector());
445  VectorType targetType =
446  getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
447  if (!targetType)
448  return failure();
449 
450  // 2.5 Compute the distributed type for the new mask;
451  VectorType maskType;
452  if (writeOp.getMask()) {
453  // TODO: Distribution of masked writes with non-trivial permutation maps
454  // requires the distribution of the mask to elementwise match the
455  // distribution of the permuted written vector. Currently the details
456  // of which lane is responsible for which element is captured strictly
457  // by shape information on the warp op, and thus requires materializing
458  // the permutation in IR.
459  if (!writeOp.getPermutationMap().isMinorIdentity())
460  return failure();
461  maskType =
462  getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
463  }
464 
465  // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
466  // the rest.
467  vector::TransferWriteOp newWriteOp =
468  cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
469 
470  // 4. Reindex the write using the distribution map.
471  auto newWarpOp =
472  newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
473 
474  // Delinearize the lane id based on the way threads are divided across the
475  // vector. To get the number of threads per vector dimension, divide the
476  // sequential size by the distributed size along each dim.
477  rewriter.setInsertionPoint(newWriteOp);
478  SmallVector<OpFoldResult> delinearizedIdSizes;
479  for (auto [seqSize, distSize] :
480  llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
481  assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
482  delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
483  }
484  SmallVector<Value> delinearized;
485  if (map.getNumResults() > 1) {
486  delinearized = mlir::affine::AffineDelinearizeIndexOp::create(
487  rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(),
488  delinearizedIdSizes)
489  .getResults();
490  } else {
491  // If there is only one map result, we can elide the delinearization
492  // op and use the lane id directly.
493  delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
494  }
495 
496  AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
497  Location loc = newWriteOp.getLoc();
498  SmallVector<Value> indices(newWriteOp.getIndices().begin(),
499  newWriteOp.getIndices().end());
500  for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
501  AffineExpr d0, d1;
502  bindDims(newWarpOp.getContext(), d0, d1);
503  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
504  if (!indexExpr)
505  continue;
506  unsigned indexPos = indexExpr.getPosition();
507  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
508  Value laneId = delinearized[vectorPos];
509  auto scale =
510  rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
511  indices[indexPos] = affine::makeComposedAffineApply(
512  rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
513  }
514  newWriteOp.getIndicesMutable().assign(indices);
515 
516  return success();
517  }
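  // As a concrete instance of the reindexing above (hypothetical shapes): a
  // write of vector<32xf32> with permutation map (d0) -> (d0), distributed to
  // vector<1xf32> over 32 lanes, rewrites its index %i to
  //   %i' = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %laneid)
  // i.e. each lane stores its single element at %i + %laneid; in general the
  // scale factor is the distributed size along the corresponding dimension.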
518 
519  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
520  LogicalResult tryExtractOp(RewriterBase &rewriter,
521  vector::TransferWriteOp writeOp,
522  WarpExecuteOnLane0Op warpOp) const {
523  Location loc = writeOp.getLoc();
524  VectorType vecType = writeOp.getVectorType();
525 
526  if (vecType.getNumElements() > maxNumElementsToExtract) {
527  return rewriter.notifyMatchFailure(
528  warpOp,
529  llvm::formatv(
530  "writes more elements ({0}) than allowed to extract ({1})",
531  vecType.getNumElements(), maxNumElementsToExtract));
532  }
533 
534  // Do not process warp ops that contain only TransferWriteOps.
535  if (llvm::all_of(warpOp.getOps(),
536  llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>))
537  return failure();
538 
539  SmallVector<Value> yieldValues = {writeOp.getVector()};
540  SmallVector<Type> retTypes = {vecType};
541  SmallVector<size_t> newRetIndices;
542  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
543  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
544  rewriter.setInsertionPointAfter(newWarpOp);
545 
546  // Create a second warp op that contains only writeOp.
547  auto secondWarpOp = WarpExecuteOnLane0Op::create(rewriter, loc, TypeRange(),
548  newWarpOp.getLaneid(),
549  newWarpOp.getWarpSize());
550  Block &body = secondWarpOp.getBodyRegion().front();
551  rewriter.setInsertionPointToStart(&body);
552  auto newWriteOp =
553  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
554  newWriteOp.getValueToStoreMutable().assign(
555  newWarpOp.getResult(newRetIndices[0]));
556  rewriter.eraseOp(writeOp);
557  gpu::YieldOp::create(rewriter, newWarpOp.getLoc());
558  return success();
559  }
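  // Schematically (hypothetical IR), a small vector<1xf32> write is peeled
  // into its own warp op so that only the store remains predicated:
  // ```
  // %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
  //   ...                          // original body, now also yielding %v
  //   gpu.yield %v : vector<1xf32>
  // }
  // gpu.warp_execute_on_lane_0(%laneid)[32] {
  //   vector.transfer_write %r, %dest[%c0] : vector<1xf32>, memref<128xf32>
  //   gpu.yield
  // }
  // ```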
560 
561  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
562  PatternRewriter &rewriter) const override {
563  gpu::YieldOp yield = warpOp.getTerminator();
564  Operation *lastNode = yield->getPrevNode();
565  auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
566  if (!writeOp)
567  return failure();
568 
569  Value maybeMask = writeOp.getMask();
570  if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
571  return writeOp.getVector() == value ||
572  (maybeMask && maybeMask == value) ||
573  warpOp.isDefinedOutsideOfRegion(value);
574  }))
575  return failure();
576 
577  if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
578  return success();
579 
580  // Masked writes not supported for extraction.
581  if (writeOp.getMask())
582  return failure();
583 
584  if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
585  return success();
586 
587  return failure();
588  }
589 
590 private:
591  /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp
592  /// execute op with the proper return type. The new write op is updated to
593  /// write the result of the new warp execute op. The old `writeOp` is deleted.
594  vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
595  WarpExecuteOnLane0Op warpOp,
596  vector::TransferWriteOp writeOp,
597  VectorType targetType,
598  VectorType maybeMaskType) const {
599  assert(writeOp->getParentOp() == warpOp &&
600  "write must be nested immediately under warp");
601  OpBuilder::InsertionGuard g(rewriter);
602  SmallVector<size_t> newRetIndices;
603  WarpExecuteOnLane0Op newWarpOp;
604  if (maybeMaskType) {
605  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
606  rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
607  TypeRange{targetType, maybeMaskType}, newRetIndices);
608  } else {
609  newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
610  rewriter, warpOp, ValueRange{{writeOp.getVector()}},
611  TypeRange{targetType}, newRetIndices);
612  }
613  rewriter.setInsertionPointAfter(newWarpOp);
614  auto newWriteOp =
615  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
616  rewriter.eraseOp(writeOp);
617  newWriteOp.getValueToStoreMutable().assign(
618  newWarpOp.getResult(newRetIndices[0]));
619  if (maybeMaskType)
620  newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
621  return newWriteOp;
622  }
623 
624  DistributionMapFn distributionMapFn;
625  unsigned maxNumElementsToExtract = 1;
626 };
627 
628 /// Sink out elementwise op feeding into a warp op yield.
629 /// ```
630 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
631 /// ...
632 /// %3 = arith.addf %1, %2 : vector<32xf32>
633 /// gpu.yield %3 : vector<32xf32>
634 /// }
635 /// ```
636 /// To
637 /// ```
638 /// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
639 /// vector<1xf32>, vector<1xf32>) {
640 /// ...
641 /// %4 = arith.addf %2, %3 : vector<32xf32>
642 /// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
643 /// vector<32xf32>
644 /// }
645 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
646 struct WarpOpElementwise : public WarpDistributionPattern {
647  using Base::Base;
648  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
649  PatternRewriter &rewriter) const override {
650  OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
651       return OpTrait::hasElementwiseMappableTraits(op);
652     });
653  if (!yieldOperand)
654  return failure();
655 
656  Operation *elementWise = yieldOperand->get().getDefiningOp();
657  unsigned operandIndex = yieldOperand->getOperandNumber();
658  Value distributedVal = warpOp.getResult(operandIndex);
659  SmallVector<Value> yieldValues;
660  SmallVector<Type> retTypes;
661  Location loc = warpOp.getLoc();
662  for (OpOperand &operand : elementWise->getOpOperands()) {
663  Type targetType;
664  if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
665  // If the result type is a vector, the operands must also be vectors.
666  auto operandType = cast<VectorType>(operand.get().getType());
667  targetType =
668  VectorType::get(vecType.getShape(), operandType.getElementType());
669  } else {
670  auto operandType = operand.get().getType();
671  assert(!isa<VectorType>(operandType) &&
672  "unexpected yield of vector from op with scalar result type");
673  targetType = operandType;
674  }
675  retTypes.push_back(targetType);
676  yieldValues.push_back(operand.get());
677  }
678  SmallVector<size_t> newRetIndices;
679  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
680  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
681  rewriter.setInsertionPointAfter(newWarpOp);
682  SmallVector<Value> newOperands(elementWise->getOperands().begin(),
683  elementWise->getOperands().end());
684  for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
685  newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
686  }
687  OpBuilder::InsertionGuard g(rewriter);
688  rewriter.setInsertionPointAfter(newWarpOp);
689     Operation *newOp = cloneOpWithOperandsAndTypes(
690         rewriter, loc, elementWise, newOperands,
691  {newWarpOp.getResult(operandIndex).getType()});
692  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
693  newOp->getResult(0));
694  return success();
695  }
696 };
697 
698 /// Sink out splat constant op feeding into a warp op yield.
699 /// ```
700 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
701 /// ...
702 /// %cst = arith.constant dense<2.0> : vector<32xf32>
703 /// gpu.yield %cst : vector<32xf32>
704 /// }
705 /// ```
706 /// To
707 /// ```
708 /// gpu.warp_execute_on_lane_0(%arg0) {
709 /// ...
710 /// }
711 /// %0 = arith.constant dense<2.0> : vector<1xf32>
712 struct WarpOpConstant : public WarpDistributionPattern {
713  using Base::Base;
714  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
715  PatternRewriter &rewriter) const override {
716  OpOperand *yieldOperand =
717  getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
718  if (!yieldOperand)
719  return failure();
720  auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
721  auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
722  if (!dense)
723  return failure();
724  // Notify the rewriter that the warp op is changing (see the comment on
725  // the WarpOpTransferRead pattern).
726  rewriter.startOpModification(warpOp);
727  unsigned operandIndex = yieldOperand->getOperandNumber();
728  Attribute scalarAttr = dense.getSplatValue<Attribute>();
729  auto newAttr = DenseElementsAttr::get(
730  cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
731  Location loc = warpOp.getLoc();
732  rewriter.setInsertionPointAfter(warpOp);
733  Value distConstant = arith::ConstantOp::create(rewriter, loc, newAttr);
734  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
735  rewriter.finalizeOpModification(warpOp);
736  return success();
737  }
738 };
739 
740 /// Sink out step op feeding into a warp op yield.
741 /// The vector.step op is treated similarly to arith.constant, except that
742 /// its result represents the sequence [0, vec_size).
743 /// Due to the vec_size == warp_size limitation,
744 /// we can simply wrap the lane id into a vector (i.e., broadcast).
745 /// Supporting vec_size != warp_size may involve preserving the step
746 /// result and using additional arith ops (the exact details are TBD).
747 /// ```
748 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) {
749 /// ...
750 /// %cst = vector.step : vector<32xindex>
751 /// gpu.yield %cst : vector<32xindex>
752 /// }
753 /// ```
754 /// To
755 /// ```
756 /// gpu.warp_execute_on_lane_0(%arg0) {
757 /// ...
758 /// }
759 /// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex>
760 struct WarpOpStep final : public WarpDistributionPattern {
761  using Base::Base;
762  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
763  PatternRewriter &rewriter) const override {
764  OpOperand *yieldOperand =
765  getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>);
766  if (!yieldOperand)
767  return failure();
768  const unsigned operandIdx = yieldOperand->getOperandNumber();
769  auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>();
770  VectorType resTy = stepOp.getResult().getType();
771  if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize()))
772  return rewriter.notifyMatchFailure(
773  warpOp,
774  llvm::formatv("Expected result size ({0}) to be of warp size ({1})",
775  resTy.getNumElements(), warpOp.getWarpSize()));
776  VectorType newVecTy =
777  cast<VectorType>(warpOp.getResult(operandIdx).getType());
778  rewriter.setInsertionPointAfter(warpOp);
779  Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(),
780  newVecTy, warpOp.getLaneid());
781  rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec);
782  return success();
783  }
784 };
785 
786 /// Sink out transfer_read op feeding into a warp op yield.
787 /// ```
788 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
789 /// ...
790 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
791 /// vector<32xf32>
792 /// gpu.yield %2 : vector<32xf32>
793 /// }
794 /// ```
795 /// To
796 /// ```
797 /// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
798 /// vector<1xf32>, vector<1xf32>) {
799 /// ...
800 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<32xf32>
801 /// gpu.yield %2 : vector<32xf32>
802 /// }
803 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
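/// ```
/// Schematically, the rewritten read additionally offsets its index per lane
/// (the scale being the distributed dimension size, here 1):
/// ```
/// %i = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%c0, %laneid)
/// %0 = vector.transfer_read %src[%i], %cst : memref<1024xf32>, vector<1xf32>
/// ```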
804 struct WarpOpTransferRead : public WarpDistributionPattern {
805  using Base::Base;
806  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
807  PatternRewriter &rewriter) const override {
808  // Try to find a distributable yielded read. Note that this pattern can
809  // still fail at the end after distribution, in which case this might have
810  // missed another distributable read.
811  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
812  // Don't duplicate transfer_read ops when distributing.
813  return isa<vector::TransferReadOp>(op) && op->hasOneUse();
814  });
815  if (!operand)
816  return rewriter.notifyMatchFailure(
817  warpOp, "warp result is not a vector.transfer_read op");
818  auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
819 
820  // Source must be defined outside of the region.
821  if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
822  return rewriter.notifyMatchFailure(
823  read, "source must be defined outside of the region");
824 
825  unsigned operandIndex = operand->getOperandNumber();
826  Value distributedVal = warpOp.getResult(operandIndex);
827 
828  SmallVector<Value, 4> indices(read.getIndices().begin(),
829  read.getIndices().end());
830  auto sequentialType = cast<VectorType>(read.getResult().getType());
831  auto distributedType = cast<VectorType>(distributedVal.getType());
832  AffineMap map = calculateImplicitMap(sequentialType, distributedType);
833  AffineMap indexMap = map.compose(read.getPermutationMap());
834 
835  // Try to delinearize the lane ID to match the rank expected for
836  // distribution.
837  SmallVector<Value> delinearizedIds;
838  if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
839  distributedType.getShape(), warpOp.getWarpSize(),
840  warpOp.getLaneid(), delinearizedIds)) {
841  return rewriter.notifyMatchFailure(
842  read, "cannot delinearize lane ID for distribution");
843  }
844  assert(!delinearizedIds.empty() || map.getNumResults() == 0);
845 
846  // Distribute indices and the mask (if present).
847  OpBuilder::InsertionGuard g(rewriter);
848  SmallVector<Value> additionalResults(indices.begin(), indices.end());
849  SmallVector<Type> additionalResultTypes(indices.size(),
850  rewriter.getIndexType());
851  additionalResults.push_back(read.getPadding());
852  additionalResultTypes.push_back(read.getPadding().getType());
853 
854  bool hasMask = false;
855  if (read.getMask()) {
856  hasMask = true;
857  // TODO: Distribution of masked reads with non-trivial permutation maps
858  // requires the distribution of the mask to elementwise match the
859       // distribution of the permuted read vector. Currently the details
860  // of which lane is responsible for which element is captured strictly
861  // by shape information on the warp op, and thus requires materializing
862  // the permutation in IR.
863  if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
864  return rewriter.notifyMatchFailure(
865  read, "non-trivial permutation maps not supported");
866  VectorType maskType =
867  getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
868  additionalResults.push_back(read.getMask());
869  additionalResultTypes.push_back(maskType);
870  }
871 
872  SmallVector<size_t> newRetIndices;
873  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
874  rewriter, warpOp, additionalResults, additionalResultTypes,
875  newRetIndices);
876  distributedVal = newWarpOp.getResult(operandIndex);
877 
878  // Distributed indices were appended first.
879  SmallVector<Value> newIndices;
880  for (int64_t i = 0, e = indices.size(); i < e; ++i)
881  newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
882 
883  rewriter.setInsertionPointAfter(newWarpOp);
884  for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
885  AffineExpr d0, d1;
886  bindDims(read.getContext(), d0, d1);
887  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
888  if (!indexExpr)
889  continue;
890  unsigned indexPos = indexExpr.getPosition();
891  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
892  int64_t scale = distributedType.getDimSize(vectorPos);
893  newIndices[indexPos] = affine::makeComposedAffineApply(
894  rewriter, read.getLoc(), d0 + scale * d1,
895  {newIndices[indexPos], delinearizedIds[vectorPos]});
896  }
897 
898  // Distributed padding value was appended right after the indices.
899  Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
900  // Distributed mask value was added at the end (if the op has a mask).
901  Value newMask =
902  hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
903  : Value();
904  auto newRead = vector::TransferReadOp::create(
905  rewriter, read.getLoc(), distributedVal.getType(), read.getBase(),
906  newIndices, read.getPermutationMapAttr(), newPadding, newMask,
907  read.getInBoundsAttr());
908 
909  rewriter.replaceAllUsesWith(distributedVal, newRead);
910  return success();
911  }
912 };
913 
914 /// Remove any result that has no use along with the matching yieldOp operand.
915 // TODO: Move this to WarpExecuteOnLane0Op canonicalization.
916 struct WarpOpDeadResult : public WarpDistributionPattern {
917  using Base::Base;
918  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
919  PatternRewriter &rewriter) const override {
920  SmallVector<Type> newResultTypes;
921  newResultTypes.reserve(warpOp->getNumResults());
922  SmallVector<Value> newYieldValues;
923  newYieldValues.reserve(warpOp->getNumResults());
924  DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
925  DenseMap<OpResult, int64_t> dedupResultPositionMap;
926  gpu::YieldOp yield = warpOp.getTerminator();
927 
928  // Some values may be yielded multiple times and correspond to multiple
929  // results. Deduplicating occurs by taking each result with its matching
930  // yielded value, and:
931  // 1. recording the unique first position at which the value is yielded.
932  // 2. recording for the result, the first position at which the dedup'ed
933  // value is yielded.
934  // 3. skipping from the new result types / new yielded values any result
935  // that has no use or whose yielded value has already been seen.
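    // For example (hypothetical values): if the old op yields (%a, %b, %a)
    // and result #1 is unused, the new op yields just (%a); results #0 and
    // #2 both map to new result #0, and %b is dropped entirely.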
936  for (OpResult result : warpOp.getResults()) {
937  Value yieldOperand = yield.getOperand(result.getResultNumber());
938  auto it = dedupYieldOperandPositionMap.insert(
939  std::make_pair(yieldOperand, newResultTypes.size()));
940  dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
941  if (result.use_empty() || !it.second)
942  continue;
943  newResultTypes.push_back(result.getType());
944  newYieldValues.push_back(yieldOperand);
945  }
946  // No modification, exit early.
947  if (yield.getNumOperands() == newYieldValues.size())
948  return failure();
949  // Move the body of the old warpOp to a new warpOp.
950  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
951  rewriter, warpOp, newYieldValues, newResultTypes);
952 
953  // Simplify the new warp op after dropping dead results.
954  newWarpOp.getBody()->walk([&](Operation *op) {
955  if (isOpTriviallyDead(op))
956  rewriter.eraseOp(op);
957  });
958 
959  // Replace results of the old warpOp by the new, deduplicated results.
960  SmallVector<Value> newValues;
961  newValues.reserve(warpOp->getNumResults());
962  for (OpResult result : warpOp.getResults()) {
963  if (result.use_empty())
964  newValues.push_back(Value());
965  else
966  newValues.push_back(
967  newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
968  }
969  rewriter.replaceOp(warpOp, newValues);
970  return success();
971  }
972 };
973 
974 // If an operand is directly yielded out of the region we can forward it
975 // directly and it doesn't need to go through the region.
976 struct WarpOpForwardOperand : public WarpDistributionPattern {
977  using Base::Base;
978  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
979  PatternRewriter &rewriter) const override {
980  gpu::YieldOp yield = warpOp.getTerminator();
981  Value valForwarded;
982  unsigned resultIndex;
983  for (OpOperand &operand : yield->getOpOperands()) {
984  Value result = warpOp.getResult(operand.getOperandNumber());
985  if (result.use_empty())
986  continue;
987 
988  // Assume all the values coming from above are uniform.
989  if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
990  if (result.getType() != operand.get().getType())
991  continue;
992  valForwarded = operand.get();
993  resultIndex = operand.getOperandNumber();
994  break;
995  }
996  auto arg = dyn_cast<BlockArgument>(operand.get());
997  if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
998  continue;
999  Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1000  if (result.getType() != warpOperand.getType())
1001  continue;
1002  valForwarded = warpOperand;
1003  resultIndex = operand.getOperandNumber();
1004  break;
1005  }
1006  if (!valForwarded)
1007  return failure();
1008  // Notify the rewriter that the warp op is changing (see the comment on
1009  // the WarpOpTransferRead pattern).
1010  rewriter.startOpModification(warpOp);
1011  rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1012  rewriter.finalizeOpModification(warpOp);
1013  return success();
1014  }
1015 };
1016 
1017 struct WarpOpBroadcast : public WarpDistributionPattern {
1018  using Base::Base;
1019  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1020  PatternRewriter &rewriter) const override {
1021  OpOperand *operand =
1022  getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1023  if (!operand)
1024  return failure();
1025  unsigned int operandNumber = operand->getOperandNumber();
1026  auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1027  Location loc = broadcastOp.getLoc();
1028  auto destVecType =
1029  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1030  Value broadcastSrc = broadcastOp.getSource();
1031  Type broadcastSrcType = broadcastSrc.getType();
1032 
1033  // Check that the broadcast actually spans a set of values uniformly across
1034  // all threads. In other words, check that each thread can reconstruct
1035  // their own broadcast.
1036  // For that we simply check that the broadcast we want to build makes sense.
1037  if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1038         vector::BroadcastableToResult::Success)
1039       return failure();
1040  SmallVector<size_t> newRetIndices;
1041  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1042  rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1043  rewriter.setInsertionPointAfter(newWarpOp);
1044  Value broadcasted = vector::BroadcastOp::create(
1045  rewriter, loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1046  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1047  broadcasted);
1048  return success();
1049  }
1050 };
1051 
1052 /// Pattern to move shape cast out of the warp op. shape cast is basically a
1053 /// no-op for warp distribution; we need to handle the shape though.
1054 struct WarpOpShapeCast : public WarpDistributionPattern {
1055  using Base::Base;
1056  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1057  PatternRewriter &rewriter) const override {
1058  OpOperand *operand =
1059  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1060  if (!operand)
1061  return failure();
1062 
1063  auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1064 
1065  unsigned int operandNumber = operand->getOperandNumber();
1066  auto castDistributedType =
1067  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1068  VectorType castOriginalType = oldCastOp.getSourceVectorType();
1069  VectorType castResultType = castDistributedType;
1070 
1071  // We expect the distributed type to have a smaller rank than the original
1072  // type. Prepend with size-one dimensions to make them the same.
1073  unsigned castDistributedRank = castDistributedType.getRank();
1074  unsigned castOriginalRank = castOriginalType.getRank();
1075  if (castDistributedRank < castOriginalRank) {
1076  SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1077  llvm::append_range(shape, castDistributedType.getShape());
1078  castDistributedType =
1079  VectorType::get(shape, castDistributedType.getElementType());
1080  }
1081 
1082  SmallVector<size_t> newRetIndices;
1083  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1084  rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1085  newRetIndices);
1086  rewriter.setInsertionPointAfter(newWarpOp);
1087  Value newCast = vector::ShapeCastOp::create(
1088  rewriter, oldCastOp.getLoc(), castResultType,
1089  newWarpOp->getResult(newRetIndices[0]));
1090  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1091  return success();
1092  }
1093 };
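// For example (hypothetical types): if the warp op yields
// %c = vector.shape_cast %s : vector<1x32xf32> to vector<32xf32> and the
// result is distributed to vector<1xf32>, the source's distributed type is
// padded with a leading unit dimension to vector<1x1xf32>; the new warp op
// yields %s as vector<1x1xf32> and a shape cast vector<1x1xf32> to
// vector<1xf32> is re-created after it.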
1094 
1095 /// Sink out vector.create_mask op feeding into a warp op yield.
1096 /// ```
1097 /// %0 = ...
1098 /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1099 /// ...
1100 /// %mask = vector.create_mask %0 : vector<32xi1>
1101 /// gpu.yield %mask : vector<32xi1>
1102 /// }
1103 /// ```
1104 /// To
1105 /// ```
1106 /// %0 = ...
1107 /// gpu.warp_execute_on_lane_0(%arg0) {
1108 /// ...
1109 /// }
1110 /// %cmp = arith.cmpi ult, %laneid, %0
1111 /// %ub = arith.select %cmp, %c0, %c1
1112 /// %1 = vector.create_mask %ub : vector<1xi1>
1113 struct WarpOpCreateMask : public WarpDistributionPattern {
1114  using Base::Base;
1115  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1116  PatternRewriter &rewriter) const override {
1117  OpOperand *yieldOperand =
1118  getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1119  if (!yieldOperand)
1120  return failure();
1121 
1122  auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1123 
1124  // Early exit if any values needed for calculating the new mask indices
1125  // are defined inside the warp op.
1126  if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1127  return warpOp.isDefinedOutsideOfRegion(value);
1128  }))
1129  return failure();
1130 
1131  Location loc = mask.getLoc();
1132  unsigned operandIndex = yieldOperand->getOperandNumber();
1133 
1134  auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1135  VectorType seqType = mask.getVectorType();
1136  ArrayRef<int64_t> seqShape = seqType.getShape();
1137  ArrayRef<int64_t> distShape = distType.getShape();
1138 
1139  rewriter.setInsertionPointAfter(warpOp);
1140 
1141  // Delinearize the lane ID for constructing the distributed mask sizes.
1142  SmallVector<Value> delinearizedIds;
1143  if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1144  warpOp.getWarpSize(), warpOp.getLaneid(),
1145  delinearizedIds))
1146  return rewriter.notifyMatchFailure(
1147  mask, "cannot delinearize lane ID for distribution");
1148  assert(!delinearizedIds.empty());
1149 
1150  // Notify the rewriter that the warp op is changing (see the comment on
1151  // the WarpOpTransferRead pattern).
1152  rewriter.startOpModification(warpOp);
1153 
1154  AffineExpr s0, s1;
1155  bindSymbols(rewriter.getContext(), s0, s1);
1156  SmallVector<Value> newOperands;
1157  for (int i = 0, e = distShape.size(); i < e; ++i) {
1158  // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1159  // find the distance from the largest mask index owned by this lane to the
1160  // original mask size. `vector.create_mask` implicitly clamps mask
1161  // operands to the range [0, mask_vector_size[i]], or in other words, the
1162  // mask sizes are always in the range [0, mask_vector_size[i]).
1163       Value maskDimIdx = affine::makeComposedAffineApply(
1164           rewriter, loc, s1 - s0 * distShape[i],
1165  {delinearizedIds[i], mask.getOperand(i)});
1166  newOperands.push_back(maskDimIdx);
1167  }
1168 
1169  auto newMask =
1170  vector::CreateMaskOp::create(rewriter, loc, distType, newOperands);
1171  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1172  rewriter.finalizeOpModification(warpOp);
1173  return success();
1174  }
1175 };
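// A worked example (hypothetical values): a sequential mask
// vector.create_mask %ub : vector<32xi1> distributed to vector<1xi1> over 32
// lanes gets the per-lane operand %ub - %laneid * 1. With %ub = 20, lane 5
// computes 15 (clamped to the per-lane vector size 1, so its single bit is
// set) and lane 25 computes -5 (clamped to 0, so its bit is clear), which
// matches the original 20-element mask.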
1176 
1177 /// Sink out insert_strided_slice op feeding into a warp op yield.
1178 /// ```
1179 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<8x1xf32>) {
1180 /// ...
1181 /// %src = ... : vector<4x32xf32>
1182 /// %dest = ... : vector<8x32xf32>
1183 /// %insert = vector.insert_strided_slice %src, %dest, offsets = [0, 0],
1184 /// strides = [1, 1] : vector<4x32xf32> into vector<8x32xf32>
1185 /// gpu.yield %insert : vector<8x32xf32>
1186 /// }
1187 /// ```
1188 /// To
1189 /// ```
1190 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4x1xf32>,
1191 /// vector<8x1xf32>) {
1192 /// ...
1193 /// %src = ... : vector<4x32xf32>
1194 /// %dest = ... : vector<8x32xf32>
1195 /// gpu.yield %src, %dest : vector<4x32xf32>, vector<8x32xf32>
1196 /// }
1197 /// %insert = vector.insert_strided_slice %0#0, %0#1,
1198 /// offsets = [0, 0], strides = [1, 1] : vector<4x1xf32> into vector<8x1xf32>
1199 /// ```
1200 /// NOTE: Current support assumes that both src and dest vectors are distributed
1201 /// to lanes and sinking the insert op does not require any cross lane
1202 /// communication.
1203 struct WarpOpInsertStridedSlice : public WarpDistributionPattern {
1204  using Base::Base;
1205  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1206  PatternRewriter &rewriter) const override {
1207  OpOperand *operand =
1208  getWarpResult(warpOp, llvm::IsaPred<vector::InsertStridedSliceOp>);
1209  if (!operand)
1210  return failure();
1211  unsigned int operandNumber = operand->getOperandNumber();
1212  auto insertOp =
1213  operand->get().getDefiningOp<vector::InsertStridedSliceOp>();
1214  auto distributedType =
1215  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1216  // Distributed type must be 2D or higher.
1217  // TODO: Support 1D distributed types.
1218  if (distributedType.getRank() < 2)
1219  return rewriter.notifyMatchFailure(
1220  insertOp, "result vector type must be 2D or higher");
1221  // Find the distributed dimension of the dest vector. There should be
1222  // exactly one.
1223  auto yieldedType = cast<VectorType>(operand->get().getType());
1224  int64_t destDistributedDim =
1225  getDistributedDim(yieldedType, distributedType);
1226  assert(destDistributedDim != -1 && "could not find distributed dimension");
1227 
1228  VectorType srcType = insertOp.getSourceVectorType();
1229  VectorType destType = insertOp.getDestVectorType();
1230  // Currently we require that both source (kD) and dest (nD) vectors are
1231  // distributed. This requires that distributedDim (d) is contained in the
1232  // last k dims of the dest vector (d >= n - k).
1233  // TODO: Add support for case where source vector is not distributed.
1234  int64_t sourceDistributedDim =
1235  destDistributedDim - (destType.getRank() - srcType.getRank());
1236  if (sourceDistributedDim < 0)
1237  return rewriter.notifyMatchFailure(
1238  insertOp,
1239  "distributed dimension must be in the last k dims of dest vector");
1240  // Distributed dimension must be fully inserted.
1241  if (srcType.getDimSize(sourceDistributedDim) !=
1242  destType.getDimSize(destDistributedDim))
1243  return rewriter.notifyMatchFailure(
1244  insertOp, "distributed dimension must be fully inserted");
1245  SmallVector<int64_t> newSourceDistShape(
1246  insertOp.getSourceVectorType().getShape());
1247  newSourceDistShape[sourceDistributedDim] =
1248  distributedType.getDimSize(destDistributedDim);
1249  auto newSourceTy =
1250  VectorType::get(newSourceDistShape, distributedType.getElementType());
1251  VectorType newDestTy = distributedType;
1252  SmallVector<size_t> newRetIndices;
1253  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1254  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1255  {newSourceTy, newDestTy}, newRetIndices);
1256  rewriter.setInsertionPointAfter(newWarpOp);
1257  Value distributedSource = newWarpOp->getResult(newRetIndices[0]);
1258  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1259  // Create a new insert strided slice op that inserts distributed source into
1260  // distributed dest.
1261  Value newInsert = vector::InsertStridedSliceOp::create(
1262  rewriter, insertOp.getLoc(), distributedDest.getType(),
1263  distributedSource, distributedDest, insertOp.getOffsets(),
1264  insertOp.getStrides());
1265  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newInsert);
1266  return success();
1267  }
1268 };
1269 
1270 /// Sink out extract_strided_slice op feeding into a warp op yield.
1271 /// ```
1272 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<16x1xf32>) {
1273 /// ...
1274 /// %src = ... : vector<64x32xf32>
1275 /// %extract = vector.extract_strided_slice %src, offsets = [0], sizes = [16],
1276 /// strides = [1] : vector<64x32xf32> to vector<16x32xf32>
1277 /// gpu.yield %extract : vector<16x32xf32>
1278 /// }
1279 /// ```
1280 /// To
1281 /// ```
1282 /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<64x1xf32>) {
1283 /// ...
1284 /// %src = ... : vector<64x32xf32>
1285 /// gpu.yield %src : vector<64x32xf32>
1286 /// }
1287 /// %extract = vector.extract_strided_slice %0, offsets = [0], sizes = [16],
1288 /// strides = [1] : vector<64x1xf32> to vector<16x1xf32>
1289 /// ```
1290 /// NOTE: Current support assumes that the extraction happens only on
1291 /// non-distributed dimensions (it does not require cross-lane communication).
1292 struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
1293  using Base::Base;
1294  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1295  PatternRewriter &rewriter) const override {
1296  OpOperand *operand =
1297  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractStridedSliceOp>);
1298  if (!operand)
1299  return failure();
1300  unsigned int operandNumber = operand->getOperandNumber();
1301  auto extractOp =
1302  operand->get().getDefiningOp<vector::ExtractStridedSliceOp>();
1303  auto distributedType =
1304  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1305  // Distributed type must be 2D or higher.
1306  // TODO: Support 1D distributed types.
1307  if (distributedType.getRank() < 2)
1308  return rewriter.notifyMatchFailure(
1309  extractOp, "result vector type must be 2D or higher");
1310 
1311  // Find the distributed dimension. There should be exactly one.
1312  auto yieldedType = cast<VectorType>(operand->get().getType());
1313  int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1314  assert(distributedDim != -1 && "could not find distributed dimension");
1315 
1316  int64_t numOfExtractedDims =
1317  static_cast<int64_t>(extractOp.getSizes().size());
1318  // If the distributed dim is included in the extracted dims, then we make
1319  // sure distributed dim is fully extracted. If distributed dim is not
1320  // included in extracted dims, it is guaranteed to be fully extracted (i.e.
1321  // distributed dim comes after all the extracted dims)
1322  // TODO: Partial extraction from distributed dimension require cross lane
1323  // communication.
1324  if (distributedDim < numOfExtractedDims) {
1325  int64_t distributedDimOffset =
1326  llvm::cast<IntegerAttr>(extractOp.getOffsets()[distributedDim])
1327  .getInt();
1328  int64_t distributedDimSize =
1329  llvm::cast<IntegerAttr>(extractOp.getSizes()[distributedDim])
1330  .getInt();
1331  if (distributedDimOffset != 0 ||
1332  distributedDimSize != yieldedType.getDimSize(distributedDim))
1333  return rewriter.notifyMatchFailure(
1334  extractOp, "distributed dimension must be fully extracted");
1335  }
1336  SmallVector<int64_t> newDistributedShape(
1337  extractOp.getSourceVectorType().getShape());
1338  newDistributedShape[distributedDim] =
1339  distributedType.getDimSize(distributedDim);
1340  auto newDistributedType =
1341  VectorType::get(newDistributedShape, distributedType.getElementType());
1342  SmallVector<size_t> newRetIndices;
1343  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1344  rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1345  newRetIndices);
1346  rewriter.setInsertionPointAfter(newWarpOp);
1347  SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
1348  extractOp.getSizes(), [](Attribute attr) { return attr; });
1349  // Update the distributed sizes to match the distributed type.
1350  if (distributedDim < static_cast<int64_t>(distributedSizes.size()))
1351  distributedSizes[distributedDim] = rewriter.getI64IntegerAttr(
1352  distributedType.getDimSize(distributedDim));
1353 
1354  // Create a new extract strided slice op that extracts from the
1355  // distributed vector.
1356  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1357  Value newExtract = vector::ExtractStridedSliceOp::create(
1358  rewriter, extractOp.getLoc(), distributedType, distributedVec,
1359  extractOp.getOffsets(),
1360  ArrayAttr::get(rewriter.getContext(), distributedSizes),
1361  extractOp.getStrides());
1362  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1363  newExtract);
1364  return success();
1365  }
1366 };
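To make the shape bookkeeping above concrete, here is a hedged sketch (illustrative only, not part of the pattern) that recomputes the new distributed source type for the example in the comment: the yielded vector<16x32xf32> is distributed to vector<16x1xf32>, so the distributed dimension is dim 1 and the source vector<64x32xf32> is re-yielded as vector<64x1xf32>.
```
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

// Hedged sketch: mirrors how WarpOpExtractStridedSlice recomputes the
// distributed type of the yielded source for the example in its comment.
static mlir::VectorType exampleNewDistributedSourceType(mlir::MLIRContext *ctx) {
  auto f32 = mlir::Float32Type::get(ctx);
  auto srcType = mlir::VectorType::get({64, 32}, f32);        // yielded source
  auto distResultType = mlir::VectorType::get({16, 1}, f32);  // distributed result
  // The distributed dim is where vector<16x32xf32> and vector<16x1xf32>
  // disagree: dim 1 (32 vs. 1).
  int64_t distributedDim = 1;
  llvm::SmallVector<int64_t> newShape(srcType.getShape());
  newShape[distributedDim] = distResultType.getDimSize(distributedDim);
  return mlir::VectorType::get(newShape, f32);                // vector<64x1xf32>
}
```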
1367 
1368 /// Pattern to move out vector.extract of single element vector. Those don't
1369 /// need to be distributed and can just be propagated outside of the region.
1370 struct WarpOpExtract : public WarpDistributionPattern {
1371  using Base::Base;
1372  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1373  PatternRewriter &rewriter) const override {
1374  OpOperand *operand =
1375  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1376  if (!operand)
1377  return failure();
1378  unsigned int operandNumber = operand->getOperandNumber();
1379  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1380  VectorType extractSrcType = extractOp.getSourceVectorType();
1381  Location loc = extractOp.getLoc();
1382 
1383  // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1384  if (extractSrcType.getRank() <= 1) {
1385  return failure();
1386  }
1387 
1388  // All following cases have 2-D or higher dimensional source vectors.
1389 
1390  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1391  // There is no distribution, this is a broadcast. Simply move the extract
1392  // out of the warp op.
1393  // TODO: This could be optimized. E.g., in case of a scalar result, let
1394  // one lane extract and shuffle the result to all other lanes (same as
1395  // the 1d case).
1396  SmallVector<size_t> newRetIndices;
1397  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1398  rewriter, warpOp, {extractOp.getSource()},
1399  {extractOp.getSourceVectorType()}, newRetIndices);
1400  rewriter.setInsertionPointAfter(newWarpOp);
1401  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1402  // Extract from distributed vector.
1403  Value newExtract = vector::ExtractOp::create(
1404  rewriter, loc, distributedVec, extractOp.getMixedPosition());
1405  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1406  newExtract);
1407  return success();
1408  }
1409 
1410  // Find the distributed dimension. There should be exactly one.
1411  auto distributedType =
1412  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1413  auto yieldedType = cast<VectorType>(operand->get().getType());
1414  int64_t distributedDim = getDistributedDim(yieldedType, distributedType);
1415  assert(distributedDim != -1 && "could not find distributed dimension");
1416  (void)distributedDim;
1417 
1418  // Yield source vector from warp op.
1419  SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
1420  for (int i = 0; i < distributedType.getRank(); ++i)
1421  newDistributedShape[i + extractOp.getNumIndices()] =
1422  distributedType.getDimSize(i);
1423  auto newDistributedType =
1424  VectorType::get(newDistributedShape, distributedType.getElementType());
1425  SmallVector<size_t> newRetIndices;
1426  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1427  rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
1428  newRetIndices);
1429  rewriter.setInsertionPointAfter(newWarpOp);
1430  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1431  // Extract from distributed vector.
1432  Value newExtract = vector::ExtractOp::create(rewriter, loc, distributedVec,
1433  extractOp.getMixedPosition());
1434  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1435  newExtract);
1436  return success();
1437  }
1438 };
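The distributed branch above places the distributed dimensions of the result after the extraction indices in the re-yielded source (`newDistributedShape[i + getNumIndices()]`). A small worked example (hedged, illustrative only): with a warp of 32 lanes, extracting `[2]` from a `vector<128x96xf32>` whose result is distributed to `vector<3xf32>` makes the new warp op yield the source as `vector<128x3xf32>`; the `vector.extract` moved outside then produces the per-lane `vector<3xf32>`.
```
#include <cstdint>

// Worked example only; mirrors newDistributedShape[i + getNumIndices()].
constexpr int64_t numIndices = 1;              // vector.extract %src[2]
constexpr int64_t srcShape[] = {128, 96};      // source vector type
constexpr int64_t distResultShape[] = {3};     // 96 elements / 32 lanes
// Distributed result dim i replaces srcShape[i + numIndices]:
constexpr int64_t newDistributedShape[] = {srcShape[0], distResultShape[0]};
static_assert(newDistributedShape[numIndices] == 3,
              "dim 1 of the re-yielded source is the per-lane slice");
```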
1439 
1440 /// Pattern to move out vector.extract with a scalar result.
1441 /// Only supports 1-D and 0-D sources for now.
1442 struct WarpOpExtractScalar : public WarpDistributionPattern {
1443  WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1444  PatternBenefit b = 1)
1445  : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {}
1446  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1447  PatternRewriter &rewriter) const override {
1448  OpOperand *operand =
1449  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1450  if (!operand)
1451  return failure();
1452  unsigned int operandNumber = operand->getOperandNumber();
1453  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1454  VectorType extractSrcType = extractOp.getSourceVectorType();
1455  // Only supports 1-D or 0-D sources for now.
1456  if (extractSrcType.getRank() > 1) {
1457  return rewriter.notifyMatchFailure(
1458  extractOp, "only 0-D or 1-D source supported for now");
1459  }
1460  // TODO: Supported shuffle types should be parameterizable, similar to
1461  // `WarpShuffleFromIdxFn`.
1462  if (!extractSrcType.getElementType().isF32() &&
1463  !extractSrcType.getElementType().isInteger(32))
1464  return rewriter.notifyMatchFailure(
1465  extractOp, "only f32/i32 element types are supported");
1466  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1467  Type elType = extractSrcType.getElementType();
1468  VectorType distributedVecType;
1469  if (!is0dOrVec1Extract) {
1470  assert(extractSrcType.getRank() == 1 &&
1471  "expected that extract src rank is 0 or 1");
1472  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1473  return failure();
1474  int64_t elementsPerLane =
1475  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1476  distributedVecType = VectorType::get({elementsPerLane}, elType);
1477  } else {
1478  distributedVecType = extractSrcType;
1479  }
1480  // Yield source vector and position (if present) from warp op.
1481  SmallVector<Value> additionalResults{extractOp.getSource()};
1482  SmallVector<Type> additionalResultTypes{distributedVecType};
1483  additionalResults.append(
1484  SmallVector<Value>(extractOp.getDynamicPosition()));
1485  additionalResultTypes.append(
1486  SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1487 
1488  Location loc = extractOp.getLoc();
1489  SmallVector<size_t> newRetIndices;
1490  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1491  rewriter, warpOp, additionalResults, additionalResultTypes,
1492  newRetIndices);
1493  rewriter.setInsertionPointAfter(newWarpOp);
1494  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1495 
1496  // 0d extract: The new warp op broadcasts the source vector to all lanes.
1497  // All lanes extract the scalar.
1498  if (is0dOrVec1Extract) {
1499  Value newExtract;
1500  SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1501  newExtract =
1502  vector::ExtractOp::create(rewriter, loc, distributedVec, indices);
1503  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1504  newExtract);
1505  return success();
1506  }
1507 
1508  int64_t staticPos = extractOp.getStaticPosition()[0];
1509  OpFoldResult pos = ShapedType::isDynamic(staticPos)
1510  ? (newWarpOp->getResult(newRetIndices[1]))
1511  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1512  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1513  // the value to all other lanes.
1514  int64_t elementsPerLane = distributedVecType.getShape()[0];
1515  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1516  // tid of extracting thread: pos / elementsPerLane
1517  Value broadcastFromTid = affine::makeComposedAffineApply(
1518  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1519  // Extract at position: pos % elementsPerLane
1520  Value newPos =
1521  elementsPerLane == 1
1522  ? arith::ConstantIndexOp::create(rewriter, loc, 0).getResult()
1523  : affine::makeComposedAffineApply(rewriter, loc,
1524  sym0 % elementsPerLane, pos);
1525  Value extracted =
1526  vector::ExtractOp::create(rewriter, loc, distributedVec, newPos);
1527 
1528  // Shuffle the extracted value to all lanes.
1529  Value shuffled = warpShuffleFromIdxFn(
1530  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1531  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1532  return success();
1533  }
1534 
1535 private:
1536  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1537 };
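WarpOpExtractScalar delegates the cross-lane broadcast to the `WarpShuffleFromIdxFn` supplied by the caller (see `populatePropagateWarpVectorDistributionPatterns` below). A minimal sketch of such a callback, along the lines of what a test pass might pass in; it assumes `gpu.shuffle` in `idx` mode with i32 offset/width operands, and the builder overloads shown are assumptions that may differ across MLIR versions.
```
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

// Hedged sketch of a WarpShuffleFromIdxFn: broadcast `val` from lane
// `srcIdx` to all lanes of the warp using gpu.shuffle in "idx" mode.
static mlir::Value shuffleFromIdx(mlir::Location loc, mlir::OpBuilder &builder,
                                  mlir::Value val, mlir::Value srcIdx,
                                  int64_t warpSz) {
  using namespace mlir;
  Type i32 = builder.getI32Type();
  // gpu.shuffle expects i32 offset/width operands; srcIdx is an index value.
  Value srcIdxI32 = arith::IndexCastOp::create(builder, loc, i32, srcIdx);
  Value warpSzI32 = arith::ConstantOp::create(
      builder, loc, builder.getIntegerAttr(i32, warpSz));
  return gpu::ShuffleOp::create(builder, loc, val, srcIdxI32, warpSzI32,
                                gpu::ShuffleMode::IDX)
      .getShuffleResult();
}
```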
1538 
1539 /// Pattern to move out vector.insert with a scalar input.
1540 /// Only supports 1-D and 0-D destinations for now.
1541 struct WarpOpInsertScalar : public WarpDistributionPattern {
1542  using Base::Base;
1543  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1544  PatternRewriter &rewriter) const override {
1545  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1546  if (!operand)
1547  return failure();
1548  unsigned int operandNumber = operand->getOperandNumber();
1549  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1550  VectorType vecType = insertOp.getDestVectorType();
1551  VectorType distrType =
1552  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1553 
1554  // Only supports 1-D or 0-D destinations for now.
1555  if (vecType.getRank() > 1) {
1556  return rewriter.notifyMatchFailure(
1557  insertOp, "only 0-D or 1-D destination supported for now");
1558  }
1559 
1560  // Yield destination vector, source scalar and position from warp op.
1561  SmallVector<Value> additionalResults{insertOp.getDest(),
1562  insertOp.getValueToStore()};
1563  SmallVector<Type> additionalResultTypes{
1564  distrType, insertOp.getValueToStore().getType()};
1565  additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1566  additionalResultTypes.append(
1567  SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1568 
1569  Location loc = insertOp.getLoc();
1570  SmallVector<size_t> newRetIndices;
1571  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1572  rewriter, warpOp, additionalResults, additionalResultTypes,
1573  newRetIndices);
1574  rewriter.setInsertionPointAfter(newWarpOp);
1575  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1576  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1577  rewriter.setInsertionPointAfter(newWarpOp);
1578 
1579  OpFoldResult pos;
1580  if (vecType.getRank() != 0) {
1581  int64_t staticPos = insertOp.getStaticPosition()[0];
1582  pos = ShapedType::isDynamic(staticPos)
1583  ? (newWarpOp->getResult(newRetIndices[2]))
1584  : OpFoldResult(rewriter.getIndexAttr(staticPos));
1585  }
1586 
1587  // This condition is always true for 0-d vectors.
1588  if (vecType == distrType) {
1589  Value newInsert;
1590  SmallVector<OpFoldResult> indices;
1591  if (pos) {
1592  indices.push_back(pos);
1593  }
1594  newInsert = vector::InsertOp::create(rewriter, loc, newSource,
1595  distributedVec, indices);
1596  // Broadcast: Simply move the vector.insert op out.
1597  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1598  newInsert);
1599  return success();
1600  }
1601 
1602  // This is a distribution. Only one lane should insert.
1603  int64_t elementsPerLane = distrType.getShape()[0];
1604  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1605  // tid of inserting lane: pos / elementsPerLane
1606  Value insertingLane = affine::makeComposedAffineApply(
1607  rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
1608  // Insert position: pos % elementsPerLane
1609  Value newPos = affine::makeComposedAffineApply(
1610  rewriter, loc, sym0 % elementsPerLane, pos);
1611  Value isInsertingLane =
1612  arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1613  newWarpOp.getLaneid(), insertingLane);
1614  Value newResult =
1615  scf::IfOp::create(
1616  rewriter, loc, isInsertingLane,
1617  /*thenBuilder=*/
1618  [&](OpBuilder &builder, Location loc) {
1619  Value newInsert = vector::InsertOp::create(
1620  builder, loc, newSource, distributedVec, newPos);
1621  scf::YieldOp::create(builder, loc, newInsert);
1622  },
1623  /*elseBuilder=*/
1624  [&](OpBuilder &builder, Location loc) {
1625  scf::YieldOp::create(builder, loc, distributedVec);
1626  })
1627  .getResult(0);
1628  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1629  return success();
1630  }
1631 };
1632 
1633 struct WarpOpInsert : public WarpDistributionPattern {
1634  using Base::Base;
1635  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1636  PatternRewriter &rewriter) const override {
1637  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1638  if (!operand)
1639  return failure();
1640  unsigned int operandNumber = operand->getOperandNumber();
1641  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1642  Location loc = insertOp.getLoc();
1643 
1644  // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1645  if (insertOp.getDestVectorType().getRank() <= 1) {
1646  return failure();
1647  }
1648 
1649  // All following cases have 2-D or higher dimensional destination vectors.
1650 
1651  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1652  // There is no distribution, this is a broadcast. Simply move the insert
1653  // out of the warp op.
1654  SmallVector<size_t> newRetIndices;
1655  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1656  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1657  {insertOp.getValueToStoreType(), insertOp.getDestVectorType()},
1658  newRetIndices);
1659  rewriter.setInsertionPointAfter(newWarpOp);
1660  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1661  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1662  Value newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1663  distributedDest,
1664  insertOp.getMixedPosition());
1665  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1666  newResult);
1667  return success();
1668  }
1669 
1670  // Find the distributed dimension. There should be exactly one.
1671  auto distrDestType =
1672  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1673  auto yieldedType = cast<VectorType>(operand->get().getType());
1674  int64_t distrDestDim = -1;
1675  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1676  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1677  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1678  // support distributing multiple dimensions in the future.
1679  assert(distrDestDim == -1 && "found multiple distributed dims");
1680  distrDestDim = i;
1681  }
1682  }
1683  assert(distrDestDim != -1 && "could not find distributed dimension");
1684 
1685  // Compute the distributed source vector type.
1686  VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType());
1687  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1688  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1689  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1690  // insert a smaller vector<3xf32>.
1691  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1692  // case, one lane will insert the source vector<96xf32>. The other
1693  // lanes will not do anything.
1694  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1695  if (distrSrcDim >= 0)
1696  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1697  auto distrSrcType =
1698  VectorType::get(distrSrcShape, distrDestType.getElementType());
1699 
1700  // Yield source and dest vectors from warp op.
1701  SmallVector<size_t> newRetIndices;
1702  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1703  rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()},
1704  {distrSrcType, distrDestType}, newRetIndices);
1705  rewriter.setInsertionPointAfter(newWarpOp);
1706  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1707  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1708 
1709  // Insert into the distributed vector.
1710  Value newResult;
1711  if (distrSrcDim >= 0) {
1712  // Every lane inserts a small piece.
1713  newResult = vector::InsertOp::create(rewriter, loc, distributedSrc,
1714  distributedDest,
1715  insertOp.getMixedPosition());
1716  } else {
1717  // One lane inserts the entire source vector.
1718  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1719  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1720  SmallVector<int64_t> newPos = getAsIntegers(pos);
1721  // tid of inserting lane: pos / elementsPerLane
1722  Value insertingLane = arith::ConstantIndexOp::create(
1723  rewriter, loc, newPos[distrDestDim] / elementsPerLane);
1724  Value isInsertingLane =
1725  arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
1726  newWarpOp.getLaneid(), insertingLane);
1727  // Insert position: pos % elementsPerLane
1728  newPos[distrDestDim] %= elementsPerLane;
1729  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1730  Value newInsert = vector::InsertOp::create(builder, loc, distributedSrc,
1731  distributedDest, newPos);
1732  scf::YieldOp::create(builder, loc, newInsert);
1733  };
1734  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1735  scf::YieldOp::create(builder, loc, distributedDest);
1736  };
1737  newResult = scf::IfOp::create(rewriter, loc, isInsertingLane,
1738  /*thenBuilder=*/insertingBuilder,
1739  /*elseBuilder=*/nonInsertingBuilder)
1740  .getResult(0);
1741  }
1742 
1743  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1744  return success();
1745  }
1746 };
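The lane arithmetic in the one-lane-inserts branch above, worked through for the example in the comment (illustrative only): for `vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>` distributed on dim 0 over 32 lanes, each lane owns 128 / 32 = 4 rows, so only lane 2 / 4 = 0 performs the insert, at local row 2 % 4 = 2.
```
#include <cstdint>

// Worked example only; mirrors the integer arithmetic used above.
constexpr int64_t warpSize = 32;
constexpr int64_t destDimSize = 128;      // size of the distributed dest dim
constexpr int64_t insertPos = 2;          // static position of vector.insert
constexpr int64_t elementsPerLane = destDimSize / warpSize;    // 4
constexpr int64_t insertingLane = insertPos / elementsPerLane; // 2 / 4 == 0
constexpr int64_t localPos = insertPos % elementsPerLane;      // 2 % 4 == 2
static_assert(insertingLane == 0 && localPos == 2,
              "only lane 0 inserts, at local row 2");
```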
1747 
1748 /// Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1749 /// the scf.if is the last operation in the region so that it doesn't
1750 /// change the order of execution. This creates a new scf.if after the
1751 /// WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1752 /// the "inner" WarpExecuteOnLane0Op. Example:
1753 /// ```
1754 /// gpu.warp_execute_on_lane_0(%laneid)[32] {
1755 /// %payload = ... : vector<32xindex>
1756 /// scf.if %pred {
1757 /// vector.store %payload, %buffer[%idx] : memref<128xindex>,
1758 /// vector<32xindex>
1759 /// }
1760 /// gpu.yield
1761 /// }
1762 /// ```
1763 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1764 /// %payload = ... : vector<32xindex>
1765 /// gpu.yield %payload : vector<32xindex>
1766 /// }
1767 /// scf.if %pred {
1768 /// gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1769 /// ^bb0(%arg1: vector<32xindex>):
1770 /// vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1771 /// }
1772 /// }
1773 /// ```
1774 struct WarpOpScfIfOp : public WarpDistributionPattern {
1775  WarpOpScfIfOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1776  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1777  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1778  PatternRewriter &rewriter) const override {
1779  gpu::YieldOp warpOpYield = warpOp.getTerminator();
1780  // Only pick up `IfOp` if it is the last op in the region.
1781  Operation *lastNode = warpOpYield->getPrevNode();
1782  auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1783  if (!ifOp)
1784  return failure();
1785 
1786  // The current `WarpOp` can yield two types of values:
1787  // 1. Not results of `IfOp`:
1788  // Preserve them in the new `WarpOp`.
1789  // Collect their yield index to remap the usages.
1790  // 2. Results of `IfOp`:
1791  // They are not part of the new `WarpOp` results.
1792  // Map current warp's yield operand index to `IfOp` result idx.
1793  SmallVector<Value> nonIfYieldValues;
1794  SmallVector<unsigned> nonIfYieldIndices;
1795  llvm::SmallDenseMap<unsigned, unsigned> ifResultMapping;
1796  llvm::SmallDenseMap<unsigned, VectorType> ifResultDistTypes;
1797  for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
1798  const unsigned yieldOperandIdx = yieldOperand.getOperandNumber();
1799  if (yieldOperand.get().getDefiningOp() != ifOp.getOperation()) {
1800  nonIfYieldValues.push_back(yieldOperand.get());
1801  nonIfYieldIndices.push_back(yieldOperandIdx);
1802  continue;
1803  }
1804  OpResult ifResult = cast<OpResult>(yieldOperand.get());
1805  const unsigned ifResultIdx = ifResult.getResultNumber();
1806  ifResultMapping[yieldOperandIdx] = ifResultIdx;
1807  // If this `ifOp` result is a vector type and it is yielded by the
1808  // `WarpOp`, we keep track of the distributed type for this result.
1809  if (!isa<VectorType>(ifResult.getType()))
1810  continue;
1811  VectorType distType =
1812  cast<VectorType>(warpOp.getResult(yieldOperandIdx).getType());
1813  ifResultDistTypes[ifResultIdx] = distType;
1814  }
1815 
1816  // Collect `WarpOp`-defined values used inside `ifOp`; the new `WarpOp` must
1817  // return them.
1818  auto [escapingValuesThen, escapingValueInputTypesThen,
1819  escapingValueDistTypesThen] =
1820  getInnerRegionEscapingValues(warpOp, ifOp.getThenRegion(),
1821  distributionMapFn);
1822  auto [escapingValuesElse, escapingValueInputTypesElse,
1823  escapingValueDistTypesElse] =
1824  getInnerRegionEscapingValues(warpOp, ifOp.getElseRegion(),
1825  distributionMapFn);
1826  if (llvm::is_contained(escapingValueDistTypesThen, Type{}) ||
1827  llvm::is_contained(escapingValueDistTypesElse, Type{}))
1828  return failure();
1829 
1830  // The new `WarpOp` yields values in the following order:
1831  // 1. The branch condition.
1832  // 2. Escaping values of the then branch.
1833  // 3. Escaping values of the else branch.
1834  // 4. All non-`ifOp` yielded values.
1835  SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition()};
1836  newWarpOpYieldValues.append(escapingValuesThen.begin(),
1837  escapingValuesThen.end());
1838  newWarpOpYieldValues.append(escapingValuesElse.begin(),
1839  escapingValuesElse.end());
1840  SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition().getType()};
1841  newWarpOpDistTypes.append(escapingValueDistTypesThen.begin(),
1842  escapingValueDistTypesThen.end());
1843  newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
1844  escapingValueDistTypesElse.end());
1845 
1846  llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
1847  for (auto [idx, val] :
1848  llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
1849  origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
1850  newWarpOpYieldValues.push_back(val);
1851  newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
1852  }
1853  // Create the new `WarpOp` with the updated yield values and types.
1854  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
1855  rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1856  // Compute the distributed result types of the new `IfOp` (its results come
1857  SmallVector<Type> newIfOpDistResTypes;
1858  for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
1859  Type distType = cast<Value>(res).getType();
1860  if (auto vecType = dyn_cast<VectorType>(distType)) {
1861  AffineMap map = distributionMapFn(cast<Value>(res));
1862  // Fall back to the distribution map if the distributed type was not recorded.
1863  distType = ifResultDistTypes.count(i)
1864  ? ifResultDistTypes[i]
1865  : getDistributedType(vecType, map, warpOp.getWarpSize());
1866  }
1867  newIfOpDistResTypes.push_back(distType);
1868  }
1869  // Create a new `IfOp` outside the new `WarpOp` region.
1870  OpBuilder::InsertionGuard g(rewriter);
1871  rewriter.setInsertionPointAfter(newWarpOp);
1872  auto newIfOp = scf::IfOp::create(
1873  rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
1874  static_cast<bool>(ifOp.thenBlock()),
1875  static_cast<bool>(ifOp.elseBlock()));
1876  auto encloseRegionInWarpOp =
1877  [&](Block *oldIfBranch, Block *newIfBranch,
1878  llvm::SmallSetVector<Value, 32> &escapingValues,
1879  SmallVector<Type> &escapingValueInputTypes,
1880  size_t warpResRangeStart) {
1881  OpBuilder::InsertionGuard g(rewriter);
1882  if (!newIfBranch)
1883  return;
1884  rewriter.setInsertionPointToStart(newIfBranch);
1885  llvm::SmallDenseMap<Value, int64_t> escapeValToBlockArgIndex;
1886  SmallVector<Value> innerWarpInputVals;
1887  SmallVector<Type> innerWarpInputTypes;
1888  for (size_t i = 0; i < escapingValues.size();
1889  ++i, ++warpResRangeStart) {
1890  innerWarpInputVals.push_back(
1891  newWarpOp.getResult(warpResRangeStart));
1892  escapeValToBlockArgIndex[escapingValues[i]] =
1893  innerWarpInputTypes.size();
1894  innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
1895  }
1896  auto innerWarp = WarpExecuteOnLane0Op::create(
1897  rewriter, newWarpOp.getLoc(), newIfOp.getResultTypes(),
1898  newWarpOp.getLaneid(), newWarpOp.getWarpSize(),
1899  innerWarpInputVals, innerWarpInputTypes);
1900 
1901  innerWarp.getWarpRegion().takeBody(*oldIfBranch->getParent());
1902  innerWarp.getWarpRegion().addArguments(
1903  innerWarpInputTypes,
1904  SmallVector<Location>(innerWarpInputTypes.size(), ifOp.getLoc()));
1905 
1906  SmallVector<Value> yieldOperands;
1907  for (Value operand : oldIfBranch->getTerminator()->getOperands())
1908  yieldOperands.push_back(operand);
1909  rewriter.eraseOp(oldIfBranch->getTerminator());
1910 
1911  rewriter.setInsertionPointToEnd(innerWarp.getBody());
1912  gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
1913  rewriter.setInsertionPointAfter(innerWarp);
1914  scf::YieldOp::create(rewriter, ifOp.getLoc(), innerWarp.getResults());
1915 
1916  // Update any users of escaping values that were forwarded to the
1917  // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1918  innerWarp.walk([&](Operation *op) {
1919  for (OpOperand &operand : op->getOpOperands()) {
1920  auto it = escapeValToBlockArgIndex.find(operand.get());
1921  if (it == escapeValToBlockArgIndex.end())
1922  continue;
1923  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1924  }
1925  });
1926  mlir::vector::moveScalarUniformCode(innerWarp);
1927  };
1928  encloseRegionInWarpOp(&ifOp.getThenRegion().front(),
1929  &newIfOp.getThenRegion().front(), escapingValuesThen,
1930  escapingValueInputTypesThen, 1);
1931  if (!ifOp.getElseRegion().empty())
1932  encloseRegionInWarpOp(&ifOp.getElseRegion().front(),
1933  &newIfOp.getElseRegion().front(),
1934  escapingValuesElse, escapingValueInputTypesElse,
1935  1 + escapingValuesThen.size());
1936  // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
1937  // result.
1938  for (auto [origIdx, newIdx] : ifResultMapping)
1939  rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
1940  newIfOp.getResult(newIdx), newIfOp);
1941  // Similarly, update any users of the `WarpOp` results that were not
1942  // results of the `IfOp`.
1943  for (auto [origIdx, newIdx] : origToNewYieldIdx)
1944  rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
1945  newWarpOp.getResult(newIdx));
1946  // Remove the original `WarpOp` and `IfOp`; they should not have any uses
1947  // at this point.
1948  rewriter.eraseOp(ifOp);
1949  rewriter.eraseOp(warpOp);
1950  return success();
1951  }
1952 
1953 private:
1954  DistributionMapFn distributionMapFn;
1955 };
1956 
1957 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1958 /// the scf.ForOp is the last operation in the region so that it doesn't
1959 /// change the order of execution. This creates a new scf.for region after the
1960 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1961 /// WarpExecuteOnLane0Op region. Example:
1962 /// ```
1963 /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1964 /// ...
1965 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1966 /// -> (vector<128xf32>) {
1967 /// ...
1968 /// scf.yield %r : vector<128xf32>
1969 /// }
1970 /// gpu.yield %v1 : vector<128xf32>
1971 /// }
1972 /// ```
1973 /// To:
1974 /// %w0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1975 /// ...
1976 /// gpu.yield %v : vector<128xf32>
1977 /// }
1978 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1979 /// -> (vector<4xf32>) {
1980 /// %iw = gpu.warp_execute_on_lane_0(%laneid)
1981 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1982 /// ^bb0(%arg: vector<128xf32>):
1983 /// ...
1984 /// gpu.yield %ir : vector<128xf32>
1985 /// }
1986 /// scf.yield %iw : vector<4xf32>
1987 /// }
1988 /// ```
1989 struct WarpOpScfForOp : public WarpDistributionPattern {
1990 
1991  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1992  : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1993  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1994  PatternRewriter &rewriter) const override {
1995  gpu::YieldOp warpOpYield = warpOp.getTerminator();
1996  // Only pick up `ForOp` if it is the last op in the region.
1997  Operation *lastNode = warpOpYield->getPrevNode();
1998  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1999  if (!forOp)
2000  return failure();
2001  // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
2002  // Those Values need to be returned by the new warp op.
2003  auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2004  getInnerRegionEscapingValues(warpOp, forOp.getBodyRegion(),
2005  distributionMapFn);
2006  if (llvm::is_contained(escapingValueDistTypes, Type{}))
2007  return failure();
2008  // `WarpOp` can yield two types of values:
2009  // 1. Values that are not results of the `ForOp`:
2010  // These values must also be yielded by the new `WarpOp`. Also, we need
2011  // to record the index mapping for these values to replace them later.
2012  // 2. Values that are results of the `ForOp`:
2013  // In this case, we record the index mapping between the `WarpOp` result
2014  // index and matching `ForOp` result index.
2015  // Additionally, we keep track of the distributed types for all `ForOp`
2016  // vector results.
2017  SmallVector<Value> nonForYieldedValues;
2018  SmallVector<unsigned> nonForResultIndices;
2019  llvm::SmallDenseMap<unsigned, unsigned> forResultMapping;
2020  llvm::SmallDenseMap<unsigned, VectorType> forResultDistTypes;
2021  for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) {
2022  // Yielded value is not a result of the forOp.
2023  if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
2024  nonForYieldedValues.push_back(yieldOperand.get());
2025  nonForResultIndices.push_back(yieldOperand.getOperandNumber());
2026  continue;
2027  }
2028  OpResult forResult = cast<OpResult>(yieldOperand.get());
2029  unsigned int forResultNumber = forResult.getResultNumber();
2030  forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber;
2031  // If this `ForOp` result is a vector type and it is yielded by the
2032  // `WarpOp`, we keep track of the distributed type for this result.
2033  if (!isa<VectorType>(forResult.getType()))
2034  continue;
2035  VectorType distType = cast<VectorType>(
2036  warpOp.getResult(yieldOperand.getOperandNumber()).getType());
2037  forResultDistTypes[forResultNumber] = distType;
2038  }
2039 
2040  // Newly created `WarpOp` will yield values in following order:
2041  // 1. All init args of the `ForOp`.
2042  // 2. All escaping values.
2043  // 3. All non-`ForOp` yielded values.
2044  SmallVector<Value> newWarpOpYieldValues;
2045  SmallVector<Type> newWarpOpDistTypes;
2046  for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
2047  newWarpOpYieldValues.push_back(initArg);
2048  // Compute the distributed type for this init arg.
2049  Type distType = initArg.getType();
2050  if (auto vecType = dyn_cast<VectorType>(distType)) {
2051  // If the `ForOp` result corresponding to this init arg is already yielded
2052  // by the `WarpOp`, we can get the distributed type from the
2053  // `forResultDistTypes` map. Otherwise, we compute it using distributionMapFn.
2054  AffineMap map = distributionMapFn(initArg);
2055  distType = forResultDistTypes.count(i)
2056  ? forResultDistTypes[i]
2057  : getDistributedType(vecType, map, warpOp.getWarpSize());
2058  }
2059  newWarpOpDistTypes.push_back(distType);
2060  }
2061  // Insert escaping values and their distributed types.
2062  newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
2063  escapingValues.begin(), escapingValues.end());
2064  newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
2065  escapingValueDistTypes.begin(),
2066  escapingValueDistTypes.end());
2067  // Next, we insert all non-`ForOp` yielded values and their distributed
2068  // types. We also create a mapping between the non-`ForOp` yielded value
2069  // index and the corresponding new `WarpOp` yield value index (needed to
2070  // update users later).
2071  llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
2072  for (auto [i, v] :
2073  llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
2074  nonForResultMapping[i] = newWarpOpYieldValues.size();
2075  newWarpOpYieldValues.push_back(v);
2076  newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
2077  }
2078  // Create the new `WarpOp` with the updated yield values and types.
2079  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
2080  rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
2081 
2082  // Next, we create a new `ForOp` with the init args yielded by the new
2083  // `WarpOp`.
2084  const unsigned escapingValuesStartIdx =
2085  forOp.getInitArgs().size(); // `ForOp` init args are positioned before
2086  // escaping values in the new `WarpOp`.
2087  SmallVector<Value> newForOpOperands;
2088  for (size_t i = 0; i < escapingValuesStartIdx; ++i)
2089  newForOpOperands.push_back(newWarpOp.getResult(i));
2090 
2091  // Create a new `ForOp` outside the new `WarpOp` region.
2092  OpBuilder::InsertionGuard g(rewriter);
2093  rewriter.setInsertionPointAfter(newWarpOp);
2094  auto newForOp = scf::ForOp::create(
2095  rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
2096  forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
2097  forOp.getUnsignedCmp());
2098  // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
2099  // newly created `ForOp`. This `WarpOp` will contain all ops that were
2100  // contained within the original `ForOp` body.
2101  rewriter.setInsertionPointToStart(newForOp.getBody());
2102 
2103  SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
2104  newForOp.getRegionIterArgs().end());
2105  SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
2106  forOp.getResultTypes().end());
2107  // Escaping values are forwarded to the inner `WarpOp` as its (additional)
2108  // arguments. We keep track of the mapping between these values and their
2109  // argument index in the inner `WarpOp` (to replace users later).
2110  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
2111  for (size_t i = escapingValuesStartIdx;
2112  i < escapingValuesStartIdx + escapingValues.size(); ++i) {
2113  innerWarpInput.push_back(newWarpOp.getResult(i));
2114  argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
2115  innerWarpInputType.size();
2116  innerWarpInputType.push_back(
2117  escapingValueInputTypes[i - escapingValuesStartIdx]);
2118  }
2119  // Create the inner `WarpOp` with the new input values and types.
2120  auto innerWarp = WarpExecuteOnLane0Op::create(
2121  rewriter, newWarpOp.getLoc(), newForOp.getResultTypes(),
2122  newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput,
2123  innerWarpInputType);
2124 
2125  // Inline the `ForOp` body into the inner `WarpOp` body.
2126  SmallVector<Value> argMapping;
2127  argMapping.push_back(newForOp.getInductionVar());
2128  for (Value args : innerWarp.getBody()->getArguments())
2129  argMapping.push_back(args);
2130 
2131  argMapping.resize(forOp.getBody()->getNumArguments());
2132  SmallVector<Value> yieldOperands;
2133  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
2134  yieldOperands.push_back(operand);
2135 
2136  rewriter.eraseOp(forOp.getBody()->getTerminator());
2137  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
2138 
2139  // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields
2140  // original `ForOp` results.
2141  rewriter.setInsertionPointToEnd(innerWarp.getBody());
2142  gpu::YieldOp::create(rewriter, innerWarp.getLoc(), yieldOperands);
2143  rewriter.setInsertionPointAfter(innerWarp);
2144  // Insert a scf.yield op at the end of the new `ForOp` body that yields
2145  // the inner `WarpOp` results.
2146  if (!innerWarp.getResults().empty())
2147  scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
2148 
2149  // Update the users of original `WarpOp` results that were coming from the
2150  // original `ForOp` to the corresponding new `ForOp` result.
2151  for (auto [origIdx, newIdx] : forResultMapping)
2152  rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
2153  newForOp.getResult(newIdx), newForOp);
2154  // Similarly, update any users of the `WarpOp` results that were not
2155  // results of the `ForOp`.
2156  for (auto [origIdx, newIdx] : nonForResultMapping)
2157  rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
2158  newWarpOp.getResult(newIdx));
2159  // Remove the original `WarpOp` and `ForOp`; they should not have any uses
2160  // at this point.
2161  rewriter.eraseOp(forOp);
2162  rewriter.eraseOp(warpOp);
2163  // Update any users of escaping values that were forwarded to the
2164  // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
2165  newForOp.walk([&](Operation *op) {
2166  for (OpOperand &operand : op->getOpOperands()) {
2167  auto it = argIndexMapping.find(operand.get());
2168  if (it == argIndexMapping.end())
2169  continue;
2170  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
2171  }
2172  });
2173 
2174  // Finally, hoist out any now uniform code from the inner `WarpOp`.
2175  mlir::vector::moveScalarUniformCode(innerWarp);
2176  return success();
2177  }
2178 
2179 private:
2180  DistributionMapFn distributionMapFn;
2181 };
2182 
2183 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
2184 /// The vector is reduced in parallel. Currently limited to vector sizes that
2185 /// are a multiple of the warp size. E.g.:
2186 /// ```
2187 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
2188 /// %0 = "some_def"() : () -> (vector<32xf32>)
2189 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
2190 /// gpu.yield %1 : f32
2191 /// }
2192 /// ```
2193 /// is lowered to:
2194 /// ```
2195 /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
2196 /// %1 = "some_def"() : () -> (vector<32xf32>)
2197 /// gpu.yield %1 : vector<32xf32>
2198 /// }
2199 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
2200 /// %r = ("warp.reduction %a")
2201 /// ```
2202 struct WarpOpReduction : public WarpDistributionPattern {
2203  WarpOpReduction(MLIRContext *context,
2204  DistributedReductionFn distributedReductionFn,
2205  PatternBenefit benefit = 1)
2206  : WarpDistributionPattern(context, benefit),
2207  distributedReductionFn(std::move(distributedReductionFn)) {}
2208 
2209  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
2210  PatternRewriter &rewriter) const override {
2211  OpOperand *yieldOperand =
2212  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
2213  if (!yieldOperand)
2214  return failure();
2215 
2216  auto reductionOp =
2217  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
2218  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
2219  // Only rank 1 vectors supported.
2220  if (vectorType.getRank() != 1)
2221  return rewriter.notifyMatchFailure(
2222  warpOp, "Only rank 1 reductions can be distributed.");
2223  // Only warp_size-sized vectors supported.
2224  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
2225  return rewriter.notifyMatchFailure(
2226  warpOp, "Reduction vector dimension must be a multiple of the warp size.");
2227  if (!reductionOp.getType().isIntOrFloat())
2228  return rewriter.notifyMatchFailure(
2229  warpOp, "Reduction distribution currently only supports floats and "
2230  "integer types.");
2231 
2232  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
2233  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
2234  unsigned operandIndex = yieldOperand->getOperandNumber();
2235  SmallVector<Value> yieldValues = {reductionOp.getVector()};
2236  SmallVector<Type> retTypes = {
2237  VectorType::get({numElements}, reductionOp.getType())};
2238  if (reductionOp.getAcc()) {
2239  yieldValues.push_back(reductionOp.getAcc());
2240  retTypes.push_back(reductionOp.getAcc().getType());
2241  }
2242  SmallVector<size_t> newRetIndices;
2243  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
2244  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
2245  rewriter.setInsertionPointAfter(newWarpOp);
2246 
2247  // Obtain data to reduce for a single lane.
2248  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
2249  // Distribute and reduce across threads.
2250  Value fullReduce =
2251  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
2252  reductionOp.getKind(), newWarpOp.getWarpSize());
2253  if (reductionOp.getAcc()) {
2254  fullReduce = vector::makeArithReduction(
2255  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
2256  newWarpOp.getResult(newRetIndices[1]));
2257  }
2258  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
2259  return success();
2260  }
2261 
2262 private:
2263  DistributedReductionFn distributedReductionFn;
2264 };
2265 
2266 } // namespace
2267 
2268 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
2269  RewritePatternSet &patterns,
2270  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
2271  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
2272 }
2273 
2274 void mlir::vector::populateDistributeTransferWriteOpPatterns(
2275  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2276  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
2277  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
2278  maxNumElementsToExtract, benefit);
2279 }
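The `DistributionMapFn` passed here (and to the scf.for / scf.if patterns above) decides which vector dimension of a value is distributed across lanes; it is a `std::function<AffineMap(Value)>`. Below is a minimal sketch that always distributes the innermost dimension; the policy and the helper name are assumptions, not a prescribed choice.
```
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Hedged sketch of a DistributionMapFn: distribute the innermost vector
// dimension of `val`; return an empty map for non-vector values.
static mlir::AffineMap distributeInnermostDim(mlir::Value val) {
  auto vecType = llvm::dyn_cast<mlir::VectorType>(val.getType());
  if (!vecType)
    return mlir::AffineMap();
  int64_t rank = vecType.getRank();
  // (d0, ..., d{rank-1}) -> (d{rank-1})
  return mlir::AffineMap::get(
      rank, /*symbolCount=*/0,
      {mlir::getAffineDimExpr(rank - 1, vecType.getContext())},
      vecType.getContext());
}
```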
2280 
2281 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
2282  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
2283  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
2284  PatternBenefit readBenefit) {
2285  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
2286  patterns
2287  .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
2288  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
2289  WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask,
2290  WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>(
2291  patterns.getContext(), benefit);
2292  patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
2293  benefit);
2294  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
2295  benefit);
2296  patterns.add<WarpOpScfIfOp>(patterns.getContext(), distributionMapFn,
2297  benefit);
2298 }
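Putting the entry points above together, a pass might populate and apply these patterns roughly as follows. This is a hedged sketch: `distributeInnermostDim` and `shuffleFromIdx` are the hypothetical helpers sketched earlier, the `maxNumElementsToExtract` value is arbitrary, and the greedy driver call assumes the usual `applyPatternsGreedily` entry point.
```
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hedged sketch: wire the populate* entry points into one greedy rewrite.
static mlir::LogicalResult distributeWarpOps(mlir::Operation *root) {
  using namespace mlir;
  RewritePatternSet patterns(root->getContext());
  vector::populateDistributeTransferWriteOpPatterns(
      patterns, distributeInnermostDim, /*maxNumElementsToExtract=*/1);
  vector::populatePropagateWarpVectorDistributionPatterns(
      patterns, distributeInnermostDim, shuffleFromIdx);
  return applyPatternsGreedily(root, std::move(patterns));
}
```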
2299 
2300 void mlir::vector::populateDistributeReduction(
2301  RewritePatternSet &patterns,
2302  const DistributedReductionFn &distributedReductionFn,
2303  PatternBenefit benefit) {
2304  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
2305  benefit);
2306 }
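The `DistributedReductionFn` expected here reduces a lane-local value across the warp; `WarpOpReduction` above calls it with the per-lane vector, the combining kind, and the warp size. A hedged sketch using a `gpu.shuffle xor` butterfly follows; it assumes a single element per lane (so the per-lane vector can be reduced to a scalar with a plain extract), and the builder overloads shown are assumptions.
```
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

// Hedged sketch of a DistributedReductionFn: butterfly-reduce the lane-local
// value across the warp with gpu.shuffle xor + vector::makeArithReduction.
static mlir::Value warpReduce(mlir::Location loc, mlir::OpBuilder &builder,
                              mlir::Value laneVal,
                              mlir::vector::CombiningKind kind, uint32_t size) {
  using namespace mlir;
  // Assumes laneVal is vector<1xT>; reduce it to a scalar first.
  Value result =
      vector::ExtractOp::create(builder, loc, laneVal, ArrayRef<int64_t>{0});
  Value width = arith::ConstantOp::create(builder, loc,
                                          builder.getI32IntegerAttr(size));
  for (uint32_t offset = 1; offset < size; offset <<= 1) {
    Value delta = arith::ConstantOp::create(builder, loc,
                                            builder.getI32IntegerAttr(offset));
    Value shuffled = gpu::ShuffleOp::create(builder, loc, result, delta, width,
                                            gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    result = vector::makeArithReduction(builder, loc, kind, result, shuffled);
  }
  return result;
}
```
With such a callback, `populateDistributeReduction(patterns, warpReduce)` completes the wiring for the reduction pattern.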
2307 
2308 /// Helper to know if an op can be hoisted out of the region.
2309 static bool canBeHoisted(Operation *op,
2310  function_ref<bool(Value)> definedOutside) {
2311  return llvm::all_of(op->getOperands(), definedOutside) &&
2312  isMemoryEffectFree(op) && op->getNumRegions() == 0;
2313 }
2314 
2315 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
2316  Block *body = warpOp.getBody();
2317 
2318  // Keep track of the ops we want to hoist.
2319  llvm::SmallSetVector<Operation *, 8> opsToMove;
2320 
2321  // Helper to check if a value is or will be defined outside of the region.
2322  auto isDefinedOutsideOfBody = [&](Value value) {
2323  auto *definingOp = value.getDefiningOp();
2324  return (definingOp && opsToMove.count(definingOp)) ||
2325  warpOp.isDefinedOutsideOfRegion(value);
2326  };
2327 
2328  // Do not use walk here, as we do not want to go into nested regions and hoist
2329  // operations from there.
2330  for (auto &op : body->without_terminator()) {
2331  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
2332  return isa<VectorType>(result.getType());
2333  });
2334  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
2335  opsToMove.insert(&op);
2336  }
2337 
2338  // Move all the ops marked as uniform outside of the region.
2339  for (Operation *op : opsToMove)
2340  op->moveBefore(warpOp);
2341 }