MLIR  19.0.0git
VectorDistribute.cpp
Go to the documentation of this file.
1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/AffineExpr.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include <numeric>
21 #include <utility>
22 
23 using namespace mlir;
24 using namespace mlir::vector;
25 
26 /// Currently the distribution map is implicit based on the vector shape. In the
27 /// future it will be part of the op.
28 /// Example:
29 /// ```
30 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
31 /// ...
32 /// vector.yield %3 : vector<32x16x64xf32>
33 /// }
34 /// ```
35 /// Would have an implicit map of:
36 /// `(d0, d1, d2) -> (d0, d2)`
37 static AffineMap calculateImplicitMap(VectorType sequentialType,
38  VectorType distributedType) {
40  perm.reserve(1);
41  // Check which dimensions of the sequential type are different than the
42  // dimensions of the distributed type to know the distributed dimensions. Then
43  // associate each distributed dimension to an ID in order.
44  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
45  if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
46  perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
47  }
48  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
49  distributedType.getContext());
50  return map;
51 }
52 
53 namespace {
54 
55 /// Helper struct to create the load / store operations that permit transit
56 /// through the parallel / sequential and the sequential / parallel boundaries
57 /// when performing `rewriteWarpOpToScfFor`.
58 ///
59 /// The vector distribution dimension is inferred from the vector types.
60 struct DistributedLoadStoreHelper {
61  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
62  Value laneId, Value zero)
63  : sequentialVal(sequentialVal), distributedVal(distributedVal),
64  laneId(laneId), zero(zero) {
65  sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
66  distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
67  if (sequentialVectorType && distributedVectorType)
68  distributionMap =
69  calculateImplicitMap(sequentialVectorType, distributedVectorType);
70  }
71 
72  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
73  int64_t distributedSize = distributedVectorType.getDimSize(index);
75  return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
76  ArrayRef<Value>{laneId});
77  }
78 
79  /// Create a store during the process of distributing the
80  /// `vector.warp_execute_on_thread_0` op.
81  /// Vector distribution assumes the following convention regarding the
82  /// temporary buffers that are created to transition values. This **must**
83  /// be properly specified in the `options.warpAllocationFn`:
84  /// 1. scalars of type T transit through a memref<1xT>.
85  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
86  Operation *buildStore(RewriterBase &b, Location loc, Value val,
87  Value buffer) {
88  assert((val == distributedVal || val == sequentialVal) &&
89  "Must store either the preregistered distributed or the "
90  "preregistered sequential value.");
91  // Scalar case can directly use memref.store.
92  if (!isa<VectorType>(val.getType()))
93  return b.create<memref::StoreOp>(loc, val, buffer, zero);
94 
95  // Vector case must use vector::TransferWriteOp which will later lower to
96  // vector.store of memref.store depending on further lowerings.
97  int64_t rank = sequentialVectorType.getRank();
98  SmallVector<Value> indices(rank, zero);
99  if (val == distributedVal) {
100  for (auto dimExpr : distributionMap.getResults()) {
101  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
102  indices[index] = buildDistributedOffset(b, loc, index);
103  }
104  }
105  SmallVector<bool> inBounds(indices.size(), true);
106  return b.create<vector::TransferWriteOp>(
107  loc, val, buffer, indices,
108  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
109  }
110 
111  /// Create a load during the process of distributing the
112  /// `vector.warp_execute_on_thread_0` op.
113  /// Vector distribution assumes the following convention regarding the
114  /// temporary buffers that are created to transition values. This **must**
115  /// be properly specified in the `options.warpAllocationFn`:
116  /// 1. scalars of type T transit through a memref<1xT>.
117  /// 2. vectors of type V<shapexT> transit through a memref<shapexT>
118  ///
119  /// When broadcastMode is true, the load is not distributed to account for
120  /// the broadcast semantics of the `vector.warp_execute_on_lane_0` op.
121  ///
122  /// Example:
123  ///
124  /// ```
125  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
126  /// vector.yield %cst : f32
127  /// }
128  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
129  /// ```
130  /// This behavior described in more detail in the documentation of the op.
131  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
132 
133  // Scalar case can directly use memref.store.
134  if (!isa<VectorType>(type))
135  return b.create<memref::LoadOp>(loc, buffer, zero);
136 
137  // Other cases must be vector atm.
138  // Vector case must use vector::TransferReadOp which will later lower to
139  // vector.read of memref.read depending on further lowerings.
140  assert((type == distributedVectorType || type == sequentialVectorType) &&
141  "Must store either the preregistered distributed or the "
142  "preregistered sequential type.");
143  SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
144  if (type == distributedVectorType) {
145  for (auto dimExpr : distributionMap.getResults()) {
146  int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
147  indices[index] = buildDistributedOffset(b, loc, index);
148  }
149  }
150  SmallVector<bool> inBounds(indices.size(), true);
151  return b.create<vector::TransferReadOp>(
152  loc, cast<VectorType>(type), buffer, indices,
153  ArrayRef<bool>(inBounds.begin(), inBounds.end()));
154  }
155 
156  Value sequentialVal, distributedVal, laneId, zero;
157  VectorType sequentialVectorType, distributedVectorType;
158  AffineMap distributionMap;
159 };
160 
161 } // namespace
162 
163 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
164 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
165  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
166  ValueRange newYieldedValues, TypeRange newReturnTypes) {
167  // Create a new op before the existing one, with the extra operands.
168  OpBuilder::InsertionGuard g(rewriter);
169  rewriter.setInsertionPoint(warpOp);
170  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
171  warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
172  warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
173 
174  Region &opBody = warpOp.getBodyRegion();
175  Region &newOpBody = newWarpOp.getBodyRegion();
176  Block &newOpFirstBlock = newOpBody.front();
177  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
178  rewriter.eraseBlock(&newOpFirstBlock);
179  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
180  "expected WarpOp with single block");
181 
182  auto yield =
183  cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
184 
185  rewriter.modifyOpInPlace(
186  yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
187  return newWarpOp;
188 }
189 
190 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
191 /// `indices` return the index of each new output.
192 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
193  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
194  ValueRange newYieldedValues, TypeRange newReturnTypes,
195  llvm::SmallVector<size_t> &indices) {
196  SmallVector<Type> types(warpOp.getResultTypes().begin(),
197  warpOp.getResultTypes().end());
198  auto yield = cast<vector::YieldOp>(
199  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
200  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
201  yield.getOperands().end());
202  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
203  if (yieldValues.insert(std::get<0>(newRet))) {
204  types.push_back(std::get<1>(newRet));
205  indices.push_back(yieldValues.size() - 1);
206  } else {
207  // If the value already exit the region don't create a new output.
208  for (auto [idx, yieldOperand] :
209  llvm::enumerate(yieldValues.getArrayRef())) {
210  if (yieldOperand == std::get<0>(newRet)) {
211  indices.push_back(idx);
212  break;
213  }
214  }
215  }
216  }
217  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
218  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
219  rewriter, warpOp, yieldValues.getArrayRef(), types);
220  rewriter.replaceOp(warpOp,
221  newWarpOp.getResults().take_front(warpOp.getNumResults()));
222  return newWarpOp;
223 }
224 
225 /// Helper to know if an op can be hoisted out of the region.
226 static bool canBeHoisted(Operation *op,
227  function_ref<bool(Value)> definedOutside) {
228  return llvm::all_of(op->getOperands(), definedOutside) &&
229  isMemoryEffectFree(op) && op->getNumRegions() == 0;
230 }
231 
232 /// Return a value yielded by `warpOp` which statifies the filter lamdba
233 /// condition and is not dead.
234 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
235  const std::function<bool(Operation *)> &fn) {
236  auto yield = cast<vector::YieldOp>(
237  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
238  for (OpOperand &yieldOperand : yield->getOpOperands()) {
239  Value yieldValues = yieldOperand.get();
240  Operation *definedOp = yieldValues.getDefiningOp();
241  if (definedOp && fn(definedOp)) {
242  if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
243  return &yieldOperand;
244  }
245  }
246  return {};
247 }
248 
249 // Clones `op` into a new operation that takes `operands` and returns
250 // `resultTypes`.
// The clone is built from an OperationState carrying `op`'s name and the
// attribute list returned by `op->getAttrs()`; no regions or successors are
// added to it. It is created by `rewriter` at its current insertion point,
// with location `loc`.
252  Location loc, Operation *op,
253  ArrayRef<Value> operands,
254  ArrayRef<Type> resultTypes) {
255  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
256  op->getAttrs());
257  return rewriter.create(res);
258 }
259 
260 namespace {
261 
262 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
263 /// thread `laneId` executes the entirety of the computation.
264 ///
265 /// After the transformation:
266 /// - the IR within the scf.if op can be thought of as executing sequentially
267 /// (from the point of view of threads along `laneId`).
268 /// - the IR outside of the scf.if op can be thought of as executing in
269 /// parallel (from the point of view of threads along `laneId`).
270 ///
271 /// Values that need to transit through the parallel / sequential and the
272 /// sequential / parallel boundaries do so via reads and writes to a temporary
273 /// memory location.
274 ///
275 /// The transformation proceeds in multiple steps:
276 /// 1. Create the scf.if op.
277 /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
278 /// within the scf.if to transit the values captured from above.
279 /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are
280 /// consistent within the scf.if.
281 /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
282 /// 5. Insert appropriate writes within scf.if and reads after the scf.if to
283 /// transit the values returned by the op.
284 /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are
285 /// consistent after the scf.if.
286 /// 7. Perform late cleanups.
287 ///
288 /// All this assumes the vector distribution occurs along the most minor
289 /// distributed vector dimension.
/// NOTE(review): DistributedLoadStoreHelper derives the distributed dimensions
/// from the vector shapes (calculateImplicitMap), which is not restricted to
/// the most minor dimension; confirm whether the assumption above still holds.
290 struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
291  WarpOpToScfIfPattern(MLIRContext *context,
293  PatternBenefit benefit = 1)
294  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
295  options(options) {}
296 
297  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
298  PatternRewriter &rewriter) const override {
299  assert(warpOp.getBodyRegion().hasOneBlock() &&
300  "expected WarpOp with single block");
301  Block *warpOpBody = &warpOp.getBodyRegion().front();
302  Location loc = warpOp.getLoc();
303 
304  // This pattern has no match preconditions; start rewriting.
305  OpBuilder::InsertionGuard g(rewriter);
306  rewriter.setInsertionPoint(warpOp);
307 
308  // Step 1: Create scf.if op.
  // Only the lane whose id equals 0 executes the then-block.
309  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
310  Value isLane0 = rewriter.create<arith::CmpIOp>(
311  loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
312  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
313  /*withElseRegion=*/false);
  // Drop the default terminator: the warp body (with its vector.yield) is
  // merged in at step 4, and an scf.yield is re-created at step 7.
314  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
315 
316  // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
317  // reads within the scf.if to transit the values captured from above.
318  SmallVector<Value> bbArgReplacements;
319  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
320  Value sequentialVal = warpOpBody->getArgument(it.index());
321  Value distributedVal = it.value();
322  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
323  warpOp.getLaneid(), c0);
324 
325  // Create buffer before the ifOp.
326  rewriter.setInsertionPoint(ifOp);
327  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
328  sequentialVal.getType());
329  // Store distributed vector into buffer, before the ifOp.
330  helper.buildStore(rewriter, loc, distributedVal, buffer);
331  // Load sequential vector from buffer, inside the ifOp.
332  rewriter.setInsertionPointToStart(ifOp.thenBlock());
333  bbArgReplacements.push_back(
334  helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
335  }
336 
337  // Step 3. Insert sync after all the stores and before all the loads.
338  if (!warpOp.getArgs().empty()) {
339  rewriter.setInsertionPoint(ifOp);
340  options.warpSyncronizationFn(loc, rewriter, warpOp);
341  }
342 
343  // Step 4. Move body of warpOp to ifOp.
344  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
345 
346  // Step 5. Insert appropriate writes within scf.if and reads after the
347  // scf.if to transit the values returned by the op.
348  // TODO: at this point, we can reuse the shared memory from previous
349  // buffers.
350  SmallVector<Value> replacements;
351  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
352  Location yieldLoc = yieldOp.getLoc();
353  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
354  Value sequentialVal = it.value();
355  Value distributedVal = warpOp->getResult(it.index());
356  DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
357  warpOp.getLaneid(), c0);
358 
359  // Create buffer before the ifOp.
360  rewriter.setInsertionPoint(ifOp);
361  Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
362  sequentialVal.getType());
363 
364  // Store yielded value into buffer, inside the ifOp, before the
365  // terminator.
366  rewriter.setInsertionPoint(yieldOp);
367  helper.buildStore(rewriter, loc, sequentialVal, buffer);
368 
369  // Load distributed value from buffer, after the ifOp.
370  rewriter.setInsertionPointAfter(ifOp);
371  // Result type and yielded value type are the same. This is a broadcast.
372  // E.g.:
373  // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
374  // vector.yield %cst : f32
375  // }
376  // Both types are f32. The constant %cst is broadcasted to all lanes.
377  // This is described in more detail in the documentation of the op.
378  replacements.push_back(
379  helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
380  }
381 
382  // Step 6. Insert sync after all the stores and before all the loads.
383  if (!yieldOp.getOperands().empty()) {
    // Each load from step 5 was also inserted right after the scf.if, so
    // inserting here places the sync before all of them.
384  rewriter.setInsertionPointAfter(ifOp);
385  options.warpSyncronizationFn(loc, rewriter, warpOp);
386  }
387 
388  // Step 7. Delete terminator and add empty scf.yield.
389  rewriter.eraseOp(yieldOp);
390  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
391  rewriter.create<scf::YieldOp>(yieldLoc);
392 
393  // Compute replacements for WarpOp results.
394  rewriter.replaceOp(warpOp, replacements);
395 
396  return success();
397  }
398 
399 private:
401 };
402 
403 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
404 /// op with the proper return type.
405 /// The new write op is updated to write the result of the new warp execute op.
406 /// The old `writeOp` is deleted.
407 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
408  WarpExecuteOnLane0Op warpOp,
409  vector::TransferWriteOp writeOp,
410  VectorType targetType,
411  VectorType maybeMaskType) {
412  assert(writeOp->getParentOp() == warpOp &&
413  "write must be nested immediately under warp");
414  OpBuilder::InsertionGuard g(rewriter);
415  SmallVector<size_t> newRetIndices;
416  WarpExecuteOnLane0Op newWarpOp;
417  if (maybeMaskType) {
419  rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
420  TypeRange{targetType, maybeMaskType}, newRetIndices);
421  } else {
423  rewriter, warpOp, ValueRange{{writeOp.getVector()}},
424  TypeRange{targetType}, newRetIndices);
425  }
426  rewriter.setInsertionPointAfter(newWarpOp);
427  auto newWriteOp =
428  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
429  rewriter.eraseOp(writeOp);
430  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
431  if (maybeMaskType)
432  newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
433  return newWriteOp;
434 }
435 
436 /// Return the distributed vector type based on the original type and the
437 /// distribution map. The map is expected to have a dimension equal to the
438 /// original type rank and should be a projection where the results are the
439 /// distributed dimensions. The number of results should be equal to the number
440 /// of warp sizes which is currently limited to 1.
441 /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
442 /// and a warp size of 16 would distribute the second dimension (associated to
443 /// d1) and return vector<16x2x64>
444 static VectorType getDistributedType(VectorType originalType, AffineMap map,
445  int64_t warpSize) {
446  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
447  originalType.getShape().end());
448  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
449  unsigned position = map.getDimPosition(i);
450  if (targetShape[position] % warpSize != 0) {
451  if (warpSize % targetShape[position] != 0) {
452  return VectorType();
453  }
454  warpSize /= targetShape[position];
455  targetShape[position] = 1;
456  continue;
457  }
458  targetShape[position] = targetShape[position] / warpSize;
459  warpSize = 1;
460  break;
461  }
462  if (warpSize != 1) {
463  return VectorType();
464  }
465  VectorType targetType =
466  VectorType::get(targetShape, originalType.getElementType());
467  return targetType;
468 }
469 
470 /// Distribute transfer_write ops based on the affine map returned by
471 /// `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
472 /// will not be distributed (it should be less than the warp size).
473 ///
474 /// Example:
475 /// ```
476 /// %0 = vector.warp_execute_on_lane_0(%id){
477 /// ...
478 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
479 /// vector.yield
480 /// }
481 /// ```
482 /// To
483 /// ```
484 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
485 /// ...
486 /// vector.yield %v : vector<32xf32>
487 /// }
488 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
489 struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
490  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
491  unsigned maxNumElementsToExtract, PatternBenefit b = 1)
492  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
493  distributionMapFn(std::move(fn)),
494  maxNumElementsToExtract(maxNumElementsToExtract) {}
495 
496  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
497  /// are multiples of the distribution ratio are supported at the moment.
498  LogicalResult tryDistributeOp(RewriterBase &rewriter,
499  vector::TransferWriteOp writeOp,
500  WarpExecuteOnLane0Op warpOp) const {
501  VectorType writtenVectorType = writeOp.getVectorType();
502 
503  // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
504  // to separate it from the rest.
505  if (writtenVectorType.getRank() == 0)
506  return failure();
507 
508  // 2. Compute the distributed type.
509  AffineMap map = distributionMapFn(writeOp.getVector());
510  VectorType targetType =
511  getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
512  if (!targetType)
513  return failure();
514 
515  // 2.5 Compute the distributed type for the new mask;
516  VectorType maskType;
517  if (writeOp.getMask()) {
518  // TODO: Distribution of masked writes with non-trivial permutation maps
519  // requires the distribution of the mask to elementwise match the
520  // distribution of the permuted written vector. Currently the details
521  // of which lane is responsible for which element is captured strictly
522  // by shape information on the warp op, and thus requires materializing
523  // the permutation in IR.
524  if (!writeOp.getPermutationMap().isMinorIdentity())
525  return failure();
526  maskType =
527  getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
528  }
529 
530  // 3. clone the write into a new WarpExecuteOnLane0Op to separate it from
531  // the rest.
532  vector::TransferWriteOp newWriteOp =
533  cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);
534 
535  // 4. Reindex the write using the distribution map.
536  auto newWarpOp =
537  newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
538 
539  // Delinearize the lane id based on the way threads are divided across the
540  // vector. To get the number of threads per vector dimension, divide the
541  // sequential size by the distributed size along each dim.
542  rewriter.setInsertionPoint(newWriteOp);
543  SmallVector<OpFoldResult> delinearizedIdSizes;
544  for (auto [seqSize, distSize] :
545  llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
546  assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
547  delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
548  }
549  SmallVector<Value> delinearized;
550  if (map.getNumResults() > 1) {
551  delinearized = rewriter
552  .create<mlir::affine::AffineDelinearizeIndexOp>(
553  newWarpOp.getLoc(), newWarpOp.getLaneid(),
554  delinearizedIdSizes)
555  .getResults();
556  } else {
557  // If there is only one map result, we can elide the delinearization
558  // op and use the lane id directly.
559  delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
560  }
561 
562  AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
563  Location loc = newWriteOp.getLoc();
564  SmallVector<Value> indices(newWriteOp.getIndices().begin(),
565  newWriteOp.getIndices().end());
566  for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
567  AffineExpr d0, d1;
568  bindDims(newWarpOp.getContext(), d0, d1);
569  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
570  if (!indexExpr)
571  continue;
572  unsigned indexPos = indexExpr.getPosition();
573  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
574  Value laneId = delinearized[vectorPos];
575  auto scale =
576  rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
577  indices[indexPos] = affine::makeComposedAffineApply(
578  rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
579  }
580  newWriteOp.getIndicesMutable().assign(indices);
581 
582  return success();
583  }
584 
585  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
586  LogicalResult tryExtractOp(RewriterBase &rewriter,
587  vector::TransferWriteOp writeOp,
588  WarpExecuteOnLane0Op warpOp) const {
589  Location loc = writeOp.getLoc();
590  VectorType vecType = writeOp.getVectorType();
591 
592  if (vecType.getNumElements() > maxNumElementsToExtract) {
593  return rewriter.notifyMatchFailure(
594  warpOp,
595  llvm::formatv(
596  "writes more elements ({0}) than allowed to extract ({1})",
597  vecType.getNumElements(), maxNumElementsToExtract));
598  }
599 
600  // Do not process warp ops that contain only TransferWriteOps.
601  if (llvm::all_of(warpOp.getOps(),
602  llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
603  return failure();
604 
605  SmallVector<Value> yieldValues = {writeOp.getVector()};
606  SmallVector<Type> retTypes = {vecType};
607  SmallVector<size_t> newRetIndices;
608  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
609  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
610  rewriter.setInsertionPointAfter(newWarpOp);
611 
612  // Create a second warp op that contains only writeOp.
613  auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
614  loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
615  Block &body = secondWarpOp.getBodyRegion().front();
616  rewriter.setInsertionPointToStart(&body);
617  auto newWriteOp =
618  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
619  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
620  rewriter.eraseOp(writeOp);
621  rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
622  return success();
623  }
624 
625  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
626  PatternRewriter &rewriter) const override {
627  auto yield = cast<vector::YieldOp>(
628  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
629  Operation *lastNode = yield->getPrevNode();
630  auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
631  if (!writeOp)
632  return failure();
633 
634  Value maybeMask = writeOp.getMask();
635  if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
636  return writeOp.getVector() == value ||
637  (maybeMask && maybeMask == value) ||
638  warpOp.isDefinedOutsideOfRegion(value);
639  }))
640  return failure();
641 
642  if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
643  return success();
644 
645  // Masked writes not supported for extraction.
646  if (writeOp.getMask())
647  return failure();
648 
649  if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
650  return success();
651 
652  return failure();
653  }
654 
655 private:
656  DistributionMapFn distributionMapFn;
657  unsigned maxNumElementsToExtract = 1;
658 };
659 
660 /// Sink out elementwise op feeding into a warp op yield.
661 /// ```
662 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
663 /// ...
664 /// %3 = arith.addf %1, %2 : vector<32xf32>
665 /// vector.yield %3 : vector<32xf32>
666 /// }
667 /// ```
668 /// To
669 /// ```
670 /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
671 /// vector<1xf32>, vector<1xf32>) {
672 /// ...
673 /// %3 = arith.addf %1, %2 : vector<32xf32>
674 /// vector.yield %3, %1, %2 : vector<32xf32>, vector<32xf32>,
675 /// vector<32xf32>
676 /// }
677 /// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
678 struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
680  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
681  PatternRewriter &rewriter) const override {
  // Pick the first live yield operand whose producer satisfies the predicate.
682  OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
684  });
685  if (!yieldOperand)
686  return failure();
687 
688  Operation *elementWise = yieldOperand->get().getDefiningOp();
689  unsigned operandIndex = yieldOperand->getOperandNumber();
690  Value distributedVal = warpOp.getResult(operandIndex);
691  SmallVector<Value> yieldValues;
692  SmallVector<Type> retTypes;
693  Location loc = warpOp.getLoc();
  // Yield every operand of the op out of the warp region. For a vector result,
  // each operand's distributed type takes the distributed result's shape and
  // the operand's own element type; scalar operands keep their type.
694  for (OpOperand &operand : elementWise->getOpOperands()) {
695  Type targetType;
696  if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
697  // If the result type is a vector, the operands must also be vectors.
698  auto operandType = cast<VectorType>(operand.get().getType());
699  targetType =
700  VectorType::get(vecType.getShape(), operandType.getElementType());
701  } else {
702  auto operandType = operand.get().getType();
703  assert(!isa<VectorType>(operandType) &&
704  "unexpected yield of vector from op with scalar result type");
705  targetType = operandType;
706  }
707  retTypes.push_back(targetType);
708  yieldValues.push_back(operand.get());
709  }
710  SmallVector<size_t> newRetIndices;
711  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
712  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
713  rewriter.setInsertionPointAfter(newWarpOp);
  // Rebuild the operand list from the new warp op's results.
714  SmallVector<Value> newOperands(elementWise->getOperands().begin(),
715  elementWise->getOperands().end());
716  for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
717  newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
718  }
719  OpBuilder::InsertionGuard g(rewriter);
720  rewriter.setInsertionPointAfter(newWarpOp);
722  rewriter, loc, elementWise, newOperands,
723  {newWarpOp.getResult(operandIndex).getType()});
  // Only the uses of the warp result are redirected to the new op; the
  // original yield operand is left in place.
724  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
725  newOp->getResult(0));
726  return success();
727  }
728 };
729 
730 /// Sink out splat constant op feeding into a warp op yield.
731 /// ```
732 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
733 /// ...
734 /// %cst = arith.constant dense<2.0> : vector<32xf32>
735 /// vector.yield %cst : vector<32xf32>
736 /// }
737 /// ```
738 /// To
739 /// ```
740 /// vector.warp_execute_on_lane_0(%arg0) {
741 /// ...
742 /// }
743 /// %0 = arith.constant dense<2.0> : vector<1xf32>
/// Note: only the uses of the warp result are replaced; the warp op's result
/// and yield operand themselves are not removed by this pattern.
744 struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
746  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
747  PatternRewriter &rewriter) const override {
  // NOTE(review): only the first live constant yield is examined; if it is not
  // a splat the pattern fails even when a later constant yield is one.
748  OpOperand *yieldOperand =
749  getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
750  if (!yieldOperand)
751  return failure();
752  auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
  // Only splat constants are handled: every lane's slice of a splat holds the
  // same value, so the constant can be rebuilt outside the region with the
  // distributed result type.
753  auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
754  if (!dense)
755  return failure();
756  // Notify the rewriter that the warp op is changing (see the comment on
757  // the WarpOpTransferRead pattern).
758  rewriter.startOpModification(warpOp);
759  unsigned operandIndex = yieldOperand->getOperandNumber();
760  Attribute scalarAttr = dense.getSplatValue<Attribute>();
761  auto newAttr = DenseElementsAttr::get(
762  cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
763  Location loc = warpOp.getLoc();
764  rewriter.setInsertionPointAfter(warpOp);
765  Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
766  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
767  rewriter.finalizeOpModification(warpOp);
768  return success();
769  }
770 };
771 
772 /// Delinearize the given `laneId` into multiple dimensions, where each
773 /// dimension's size is determined by `originalShape` and `distributedShape`
774 /// together. This function expects the total numbers of threads needed for
775 /// distribution is equal to `warpSize`. Returns true and updates
776 /// `delinearizedIds` if so.
777 bool delinearizeLaneId(OpBuilder &builder, Location loc,
778  ArrayRef<int64_t> originalShape,
779  ArrayRef<int64_t> distributedShape, int64_t warpSize,
780  Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
781  // If the original shape and the distributed shape is the same, we don't
782  // distribute at all--every thread is handling the whole. For such case, we
783  // should not rely on lane IDs later. So just return an empty lane ID vector.
784  if (originalShape == distributedShape) {
785  delinearizedIds.clear();
786  return true;
787  }
788 
789  SmallVector<int64_t> sizes;
790  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
791  if (large % small != 0)
792  return false;
793  sizes.push_back(large / small);
794  }
795  if (std::accumulate(sizes.begin(), sizes.end(), 1,
796  std::multiplies<int64_t>()) != warpSize)
797  return false;
798 
799  AffineExpr s0, s1;
800  bindSymbols(builder.getContext(), s0, s1);
801 
802  int64_t usedThreads = 1;
803 
804  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
805  delinearizedIds.assign(sizes.size(), zero);
806 
807  for (int i = sizes.size() - 1; i >= 0; --i) {
808  usedThreads *= sizes[i];
809  if (usedThreads == warpSize) {
810  // We've used up all available threads. Don't need to perform modulo
811  // anymore. And we can stop the calculation for further dimensions.
812  delinearizedIds[i] = laneId;
813  break;
814  }
815  delinearizedIds[i] =
816  affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
818  builder, loc, s0.floorDiv(usedThreads), {laneId});
819  }
820  return true;
821 }
822 
823 /// Sink out transfer_read op feeding into a warp op yield.
824 /// ```
825 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
826 /// ...
827 // %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
828 // vector<32xf32>
829 /// vector.yield %2 : vector<32xf32>
830 /// }
831 /// ```
832 /// To
833 /// ```
834 /// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
835 /// vector<1xf32>, vector<1xf32>) {
836 /// ...
837 /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
838 /// vector<32xf32> vector.yield %2 : vector<32xf32>
839 /// }
840 /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
841 struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
843  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
844  PatternRewriter &rewriter) const override {
845  // Try to find a distributable yielded read. Note that this pattern can
846  // still fail at the end after distribution, in which case this might have
847  // missed another distributable read.
848  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
849  // Don't duplicate transfer_read ops when distributing.
850  return isa<vector::TransferReadOp>(op) && op->hasOneUse();
851  });
852  if (!operand)
853  return rewriter.notifyMatchFailure(
854  warpOp, "warp result is not a vector.transfer_read op");
855  auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
856 
857  // Source must be defined outside of the region.
858  if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
859  return rewriter.notifyMatchFailure(
860  read, "source must be defined outside of the region");
861 
862  unsigned operandIndex = operand->getOperandNumber();
863  Value distributedVal = warpOp.getResult(operandIndex);
864 
865  SmallVector<Value, 4> indices(read.getIndices().begin(),
866  read.getIndices().end());
867  auto sequentialType = cast<VectorType>(read.getResult().getType());
868  auto distributedType = cast<VectorType>(distributedVal.getType());
869  AffineMap map = calculateImplicitMap(sequentialType, distributedType);
870  AffineMap indexMap = map.compose(read.getPermutationMap());
871 
872  // Try to delinearize the lane ID to match the rank expected for
873  // distribution.
874  SmallVector<Value> delinearizedIds;
875  if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
876  distributedType.getShape(), warpOp.getWarpSize(),
877  warpOp.getLaneid(), delinearizedIds)) {
878  return rewriter.notifyMatchFailure(
879  read, "cannot delinearize lane ID for distribution");
880  }
881  assert(!delinearizedIds.empty() || map.getNumResults() == 0);
882 
883  // Distribute indices and the mask (if present).
884  OpBuilder::InsertionGuard g(rewriter);
885  SmallVector<Value> additionalResults(indices.begin(), indices.end());
886  SmallVector<Type> additionalResultTypes(indices.size(),
887  rewriter.getIndexType());
888  additionalResults.push_back(read.getPadding());
889  additionalResultTypes.push_back(read.getPadding().getType());
890 
891  bool hasMask = false;
892  if (read.getMask()) {
893  hasMask = true;
894  // TODO: Distribution of masked reads with non-trivial permutation maps
895  // requires the distribution of the mask to elementwise match the
896  // distribution of the permuted written vector. Currently the details
897  // of which lane is responsible for which element is captured strictly
898  // by shape information on the warp op, and thus requires materializing
899  // the permutation in IR.
900  if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
901  return rewriter.notifyMatchFailure(
902  read, "non-trivial permutation maps not supported");
903  VectorType maskType =
904  getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
905  additionalResults.push_back(read.getMask());
906  additionalResultTypes.push_back(maskType);
907  }
908 
909  SmallVector<size_t> newRetIndices;
910  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
911  rewriter, warpOp, additionalResults, additionalResultTypes,
912  newRetIndices);
913  distributedVal = newWarpOp.getResult(operandIndex);
914 
915  // Distributed indices were appended first.
916  SmallVector<Value> newIndices;
917  for (int64_t i = 0, e = indices.size(); i < e; ++i)
918  newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));
919 
920  rewriter.setInsertionPointAfter(newWarpOp);
921  for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
922  AffineExpr d0, d1;
923  bindDims(read.getContext(), d0, d1);
924  auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
925  if (!indexExpr)
926  continue;
927  unsigned indexPos = indexExpr.getPosition();
928  unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
929  int64_t scale = distributedType.getDimSize(vectorPos);
930  newIndices[indexPos] = affine::makeComposedAffineApply(
931  rewriter, read.getLoc(), d0 + scale * d1,
932  {newIndices[indexPos], delinearizedIds[vectorPos]});
933  }
934 
935  // Distributed padding value was appended right after the indices.
936  Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
937  // Distributed mask value was added at the end (if the op has a mask).
938  Value newMask =
939  hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
940  : Value();
941  auto newRead = rewriter.create<vector::TransferReadOp>(
942  read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
943  read.getPermutationMapAttr(), newPadding, newMask,
944  read.getInBoundsAttr());
945 
946  rewriter.replaceAllUsesWith(distributedVal, newRead);
947  return success();
948  }
949 };
950 
951 /// Remove any result that has no use along with the matching yieldOp operand.
952 // TODO: Move this in WarpExecuteOnLane0Op canonicalization.
953 struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
955  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
956  PatternRewriter &rewriter) const override {
957  SmallVector<Type> newResultTypes;
958  newResultTypes.reserve(warpOp->getNumResults());
959  SmallVector<Value> newYieldValues;
960  newYieldValues.reserve(warpOp->getNumResults());
961  DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
962  DenseMap<OpResult, int64_t> dedupResultPositionMap;
963  auto yield = cast<vector::YieldOp>(
964  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
965 
966  // Some values may be yielded multiple times and correspond to multiple
967  // results. Deduplicating occurs by taking each result with its matching
968  // yielded value, and:
969  // 1. recording the unique first position at which the value is yielded.
970  // 2. recording for the result, the first position at which the dedup'ed
971  // value is yielded.
972  // 3. skipping from the new result types / new yielded values any result
973  // that has no use or whose yielded value has already been seen.
974  for (OpResult result : warpOp.getResults()) {
975  Value yieldOperand = yield.getOperand(result.getResultNumber());
976  auto it = dedupYieldOperandPositionMap.insert(
977  std::make_pair(yieldOperand, newResultTypes.size()));
978  dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
979  if (result.use_empty() || !it.second)
980  continue;
981  newResultTypes.push_back(result.getType());
982  newYieldValues.push_back(yieldOperand);
983  }
984  // No modification, exit early.
985  if (yield.getNumOperands() == newYieldValues.size())
986  return failure();
987  // Move the body of the old warpOp to a new warpOp.
988  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
989  rewriter, warpOp, newYieldValues, newResultTypes);
990 
991  // Simplify the new warp op after dropping dead results.
992  newWarpOp.getBody()->walk([&](Operation *op) {
993  if (isOpTriviallyDead(op))
994  rewriter.eraseOp(op);
995  });
996 
997  // Replace results of the old warpOp by the new, deduplicated results.
998  SmallVector<Value> newValues;
999  newValues.reserve(warpOp->getNumResults());
1000  for (OpResult result : warpOp.getResults()) {
1001  if (result.use_empty())
1002  newValues.push_back(Value());
1003  else
1004  newValues.push_back(
1005  newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
1006  }
1007  rewriter.replaceOp(warpOp, newValues);
1008  return success();
1009  }
1010 };
1011 
1012 // If an operand is directly yielded out of the region we can forward it
1013 // directly and it doesn't need to go through the region.
1014 struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
1016  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1017  PatternRewriter &rewriter) const override {
1018  SmallVector<Type> resultTypes;
1019  SmallVector<Value> yieldValues;
1020  auto yield = cast<vector::YieldOp>(
1021  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1022  Value valForwarded;
1023  unsigned resultIndex;
1024  for (OpOperand &operand : yield->getOpOperands()) {
1025  Value result = warpOp.getResult(operand.getOperandNumber());
1026  if (result.use_empty())
1027  continue;
1028 
1029  // Assume all the values coming from above are uniform.
1030  if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
1031  if (result.getType() != operand.get().getType())
1032  continue;
1033  valForwarded = operand.get();
1034  resultIndex = operand.getOperandNumber();
1035  break;
1036  }
1037  auto arg = dyn_cast<BlockArgument>(operand.get());
1038  if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
1039  continue;
1040  Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
1041  if (result.getType() != warpOperand.getType())
1042  continue;
1043  valForwarded = warpOperand;
1044  resultIndex = operand.getOperandNumber();
1045  break;
1046  }
1047  if (!valForwarded)
1048  return failure();
1049  // Notify the rewriter that the warp op is changing (see the comment on
1050  // the WarpOpTransferRead pattern).
1051  rewriter.startOpModification(warpOp);
1052  rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
1053  rewriter.finalizeOpModification(warpOp);
1054  return success();
1055  }
1056 };
1057 
1058 struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1060  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1061  PatternRewriter &rewriter) const override {
1062  OpOperand *operand =
1063  getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
1064  if (!operand)
1065  return failure();
1066  unsigned int operandNumber = operand->getOperandNumber();
1067  auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
1068  Location loc = broadcastOp.getLoc();
1069  auto destVecType =
1070  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1071  Value broadcastSrc = broadcastOp.getSource();
1072  Type broadcastSrcType = broadcastSrc.getType();
1073 
1074  // Check that the broadcast actually spans a set of values uniformly across
1075  // all threads. In other words, check that each thread can reconstruct
1076  // their own broadcast.
1077  // For that we simply check that the broadcast we want to build makes sense.
1078  if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
1080  return failure();
1081  SmallVector<size_t> newRetIndices;
1082  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1083  rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
1084  rewriter.setInsertionPointAfter(newWarpOp);
1085  Value broadcasted = rewriter.create<vector::BroadcastOp>(
1086  loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
1087  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1088  broadcasted);
1089  return success();
1090  }
1091 };
1092 
1093 /// Pattern to move shape cast out of the warp op. shape cast is basically a
1094 /// no-op for warp distribution; we need to handle the shape though.
1095 struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
1097  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1098  PatternRewriter &rewriter) const override {
1099  OpOperand *operand =
1100  getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
1101  if (!operand)
1102  return failure();
1103 
1104  auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
1105 
1106  unsigned int operandNumber = operand->getOperandNumber();
1107  auto castDistributedType =
1108  cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
1109  VectorType castOriginalType = oldCastOp.getSourceVectorType();
1110  VectorType castResultType = castDistributedType;
1111 
1112  // We expect the distributed type to have a smaller rank than the original
1113  // type. Prepend with size-one dimensions to make them the same.
1114  unsigned castDistributedRank = castDistributedType.getRank();
1115  unsigned castOriginalRank = castOriginalType.getRank();
1116  if (castDistributedRank < castOriginalRank) {
1117  SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
1118  llvm::append_range(shape, castDistributedType.getShape());
1119  castDistributedType =
1120  VectorType::get(shape, castDistributedType.getElementType());
1121  }
1122 
1123  SmallVector<size_t> newRetIndices;
1124  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1125  rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
1126  newRetIndices);
1127  rewriter.setInsertionPointAfter(newWarpOp);
1128  Value newCast = rewriter.create<vector::ShapeCastOp>(
1129  oldCastOp.getLoc(), castResultType,
1130  newWarpOp->getResult(newRetIndices[0]));
1131  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
1132  return success();
1133  }
1134 };
1135 
1136 /// Sink out vector.create_mask op feeding into a warp op yield.
1137 /// ```
1138 /// %0 = ...
1139 /// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
1140 /// ...
1141 /// %mask = vector.create_mask %0 : vector<32xi1>
1142 /// vector.yield %mask : vector<32xi1>
1143 /// }
1144 /// ```
1145 /// To
1146 /// ```
1147 /// %0 = ...
1148 /// vector.warp_execute_on_lane_0(%arg0) {
1149 /// ...
1150 /// }
1151 /// %cmp = arith.cmpi ult, %laneid, %0
1152 /// %ub = arith.select %cmp, %c0, %c1
1153 /// %1 = vector.create_mask %ub : vector<1xi1>
1154 struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
1156  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1157  PatternRewriter &rewriter) const override {
1158  OpOperand *yieldOperand =
1159  getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
1160  if (!yieldOperand)
1161  return failure();
1162 
1163  auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();
1164 
1165  // Early exit if any values needed for calculating the new mask indices
1166  // are defined inside the warp op.
1167  if (!llvm::all_of(mask->getOperands(), [&](Value value) {
1168  return warpOp.isDefinedOutsideOfRegion(value);
1169  }))
1170  return failure();
1171 
1172  Location loc = mask.getLoc();
1173  unsigned operandIndex = yieldOperand->getOperandNumber();
1174 
1175  auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
1176  VectorType seqType = mask.getVectorType();
1177  ArrayRef<int64_t> seqShape = seqType.getShape();
1178  ArrayRef<int64_t> distShape = distType.getShape();
1179 
1180  rewriter.setInsertionPointAfter(warpOp);
1181 
1182  // Delinearize the lane ID for constructing the distributed mask sizes.
1183  SmallVector<Value> delinearizedIds;
1184  if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
1185  warpOp.getWarpSize(), warpOp.getLaneid(),
1186  delinearizedIds))
1187  return rewriter.notifyMatchFailure(
1188  mask, "cannot delinearize lane ID for distribution");
1189  assert(!delinearizedIds.empty());
1190 
1191  // Notify the rewriter that the warp op is changing (see the comment on
1192  // the WarpOpTransferRead pattern).
1193  rewriter.startOpModification(warpOp);
1194 
1195  AffineExpr s0, s1;
1196  bindSymbols(rewriter.getContext(), s0, s1);
1197  SmallVector<Value> newOperands;
1198  for (int i = 0, e = distShape.size(); i < e; ++i) {
1199  // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
1200  // find the distance from the largest mask index owned by this lane to the
1201  // original mask size. `vector.create_mask` implicitly clamps mask
1202  // operands to the range [0, mask_vector_size[i]], or in other words, the
1203  // mask sizes are always in the range [0, mask_vector_size[i]).
1205  rewriter, loc, s1 - s0 * distShape[i],
1206  {delinearizedIds[i], mask.getOperand(i)});
1207  newOperands.push_back(maskDimIdx);
1208  }
1209 
1210  auto newMask =
1211  rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
1212  rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
1213  rewriter.finalizeOpModification(warpOp);
1214  return success();
1215  }
1216 };
1217 
1218 /// Pattern to move out vector.extract of single element vector. Those don't
1219 /// need to be distributed and can just be propagated outside of the region.
1220 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
1222  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1223  PatternRewriter &rewriter) const override {
1224  OpOperand *operand =
1225  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
1226  if (!operand)
1227  return failure();
1228  unsigned int operandNumber = operand->getOperandNumber();
1229  auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
1230  VectorType extractSrcType = extractOp.getSourceVectorType();
1231  Location loc = extractOp.getLoc();
1232 
1233  // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
1234  assert(extractSrcType.getRank() > 0 &&
1235  "vector.extract does not support rank 0 sources");
1236 
1237  // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
1238  // canonicalized to %v.
1239  if (extractOp.getNumIndices() == 0)
1240  return failure();
1241 
1242  // Rewrite vector.extract with 1d source to vector.extractelement.
1243  if (extractSrcType.getRank() == 1) {
1244  if (extractOp.hasDynamicPosition())
1245  // TODO: Dinamic position not supported yet.
1246  return failure();
1247 
1248  assert(extractOp.getNumIndices() == 1 && "expected 1 index");
1249  int64_t pos = extractOp.getStaticPosition()[0];
1250  rewriter.setInsertionPoint(extractOp);
1251  rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
1252  extractOp, extractOp.getVector(),
1253  rewriter.create<arith::ConstantIndexOp>(loc, pos));
1254  return success();
1255  }
1256 
1257  // All following cases are 2d or higher dimensional source vectors.
1258 
1259  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1260  // There is no distribution, this is a broadcast. Simply move the extract
1261  // out of the warp op.
1262  // TODO: This could be optimized. E.g., in case of a scalar result, let
1263  // one lane extract and shuffle the result to all other lanes (same as
1264  // the 1d case).
1265  SmallVector<size_t> newRetIndices;
1266  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1267  rewriter, warpOp, {extractOp.getVector()},
1268  {extractOp.getSourceVectorType()}, newRetIndices);
1269  rewriter.setInsertionPointAfter(newWarpOp);
1270  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1271  // Extract from distributed vector.
1272  Value newExtract = rewriter.create<vector::ExtractOp>(
1273  loc, distributedVec, extractOp.getMixedPosition());
1274  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1275  newExtract);
1276  return success();
1277  }
1278 
1279  // Find the distributed dimension. There should be exactly one.
1280  auto distributedType =
1281  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1282  auto yieldedType = cast<VectorType>(operand->get().getType());
1283  int64_t distributedDim = -1;
1284  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1285  if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
1286  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1287  // support distributing multiple dimensions in the future.
1288  assert(distributedDim == -1 && "found multiple distributed dims");
1289  distributedDim = i;
1290  }
1291  }
1292  assert(distributedDim != -1 && "could not find distributed dimension");
1293  (void)distributedDim;
1294 
1295  // Yield source vector from warp op.
1296  SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
1297  extractSrcType.getShape().end());
1298  for (int i = 0; i < distributedType.getRank(); ++i)
1299  newDistributedShape[i + extractOp.getNumIndices()] =
1300  distributedType.getDimSize(i);
1301  auto newDistributedType =
1302  VectorType::get(newDistributedShape, distributedType.getElementType());
1303  SmallVector<size_t> newRetIndices;
1304  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1305  rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
1306  newRetIndices);
1307  rewriter.setInsertionPointAfter(newWarpOp);
1308  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1309  // Extract from distributed vector.
1310  Value newExtract = rewriter.create<vector::ExtractOp>(
1311  loc, distributedVec, extractOp.getMixedPosition());
1312  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1313  newExtract);
1314  return success();
1315  }
1316 };
1317 
1318 /// Pattern to move out vector.extractelement of 0-D tensors. Those don't
1319 /// need to be distributed and can just be propagated outside of the region.
1320 struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1321  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1322  PatternBenefit b = 1)
1323  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1324  warpShuffleFromIdxFn(std::move(fn)) {}
1325  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1326  PatternRewriter &rewriter) const override {
1327  OpOperand *operand =
1328  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1329  if (!operand)
1330  return failure();
1331  unsigned int operandNumber = operand->getOperandNumber();
1332  auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1333  VectorType extractSrcType = extractOp.getSourceVectorType();
1334  // TODO: Supported shuffle types should be parameterizable, similar to
1335  // `WarpShuffleFromIdxFn`.
1336  if (!extractSrcType.getElementType().isF32() &&
1337  !extractSrcType.getElementType().isInteger(32))
1338  return rewriter.notifyMatchFailure(
1339  extractOp, "only f32/i32 element types are supported");
1340  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1341  Type elType = extractSrcType.getElementType();
1342  VectorType distributedVecType;
1343  if (!is0dOrVec1Extract) {
1344  assert(extractSrcType.getRank() == 1 &&
1345  "expected that extractelement src rank is 0 or 1");
1346  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1347  return failure();
1348  int64_t elementsPerLane =
1349  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1350  distributedVecType = VectorType::get({elementsPerLane}, elType);
1351  } else {
1352  distributedVecType = extractSrcType;
1353  }
1354  // Yield source vector and position (if present) from warp op.
1355  SmallVector<Value> additionalResults{extractOp.getVector()};
1356  SmallVector<Type> additionalResultTypes{distributedVecType};
1357  if (static_cast<bool>(extractOp.getPosition())) {
1358  additionalResults.push_back(extractOp.getPosition());
1359  additionalResultTypes.push_back(extractOp.getPosition().getType());
1360  }
1361  Location loc = extractOp.getLoc();
1362  SmallVector<size_t> newRetIndices;
1363  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1364  rewriter, warpOp, additionalResults, additionalResultTypes,
1365  newRetIndices);
1366  rewriter.setInsertionPointAfter(newWarpOp);
1367  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1368 
1369  // 0d extract: The new warp op broadcasts the source vector to all lanes.
1370  // All lanes extract the scalar.
1371  if (is0dOrVec1Extract) {
1372  Value newExtract;
1373  if (extractSrcType.getRank() == 1) {
1374  newExtract = rewriter.create<vector::ExtractElementOp>(
1375  loc, distributedVec,
1376  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1377 
1378  } else {
1379  newExtract =
1380  rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
1381  }
1382  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1383  newExtract);
1384  return success();
1385  }
1386 
1387  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1388  // the value to all other lanes.
1389  int64_t elementsPerLane = distributedVecType.getShape()[0];
1390  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1391  // tid of extracting thread: pos / elementsPerLane
1392  Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
1393  loc, sym0.ceilDiv(elementsPerLane),
1394  newWarpOp->getResult(newRetIndices[1]));
1395  // Extract at position: pos % elementsPerLane
1396  Value pos =
1397  elementsPerLane == 1
1398  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1399  : rewriter
1400  .create<affine::AffineApplyOp>(
1401  loc, sym0 % elementsPerLane,
1402  newWarpOp->getResult(newRetIndices[1]))
1403  .getResult();
1404  Value extracted =
1405  rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
1406 
1407  // Shuffle the extracted value to all lanes.
1408  Value shuffled = warpShuffleFromIdxFn(
1409  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1410  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1411  return success();
1412  }
1413 
1414 private:
1415  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1416 };
1417 
1418 struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1420 
1421  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1422  PatternRewriter &rewriter) const override {
1423  OpOperand *operand =
1424  getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1425  if (!operand)
1426  return failure();
1427  unsigned int operandNumber = operand->getOperandNumber();
1428  auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1429  VectorType vecType = insertOp.getDestVectorType();
1430  VectorType distrType =
1431  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1432  bool hasPos = static_cast<bool>(insertOp.getPosition());
1433 
1434  // Yield destination vector, source scalar and position from warp op.
1435  SmallVector<Value> additionalResults{insertOp.getDest(),
1436  insertOp.getSource()};
1437  SmallVector<Type> additionalResultTypes{distrType,
1438  insertOp.getSource().getType()};
1439  if (hasPos) {
1440  additionalResults.push_back(insertOp.getPosition());
1441  additionalResultTypes.push_back(insertOp.getPosition().getType());
1442  }
1443  Location loc = insertOp.getLoc();
1444  SmallVector<size_t> newRetIndices;
1445  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1446  rewriter, warpOp, additionalResults, additionalResultTypes,
1447  newRetIndices);
1448  rewriter.setInsertionPointAfter(newWarpOp);
1449  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1450  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1451  Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
1452  rewriter.setInsertionPointAfter(newWarpOp);
1453 
1454  if (vecType == distrType) {
1455  // Broadcast: Simply move the vector.inserelement op out.
1456  Value newInsert = rewriter.create<vector::InsertElementOp>(
1457  loc, newSource, distributedVec, newPos);
1458  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1459  newInsert);
1460  return success();
1461  }
1462 
1463  // This is a distribution. Only one lane should insert.
1464  int64_t elementsPerLane = distrType.getShape()[0];
1465  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1466  // tid of extracting thread: pos / elementsPerLane
1467  Value insertingLane = rewriter.create<affine::AffineApplyOp>(
1468  loc, sym0.ceilDiv(elementsPerLane), newPos);
1469  // Insert position: pos % elementsPerLane
1470  Value pos =
1471  elementsPerLane == 1
1472  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1473  : rewriter
1474  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1475  newPos)
1476  .getResult();
1477  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1478  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1479  Value newResult =
1480  rewriter
1481  .create<scf::IfOp>(
1482  loc, isInsertingLane,
1483  /*thenBuilder=*/
1484  [&](OpBuilder &builder, Location loc) {
1485  Value newInsert = builder.create<vector::InsertElementOp>(
1486  loc, newSource, distributedVec, pos);
1487  builder.create<scf::YieldOp>(loc, newInsert);
1488  },
1489  /*elseBuilder=*/
1490  [&](OpBuilder &builder, Location loc) {
1491  builder.create<scf::YieldOp>(loc, distributedVec);
1492  })
1493  .getResult(0);
1494  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1495  return success();
1496  }
1497 };
1498 
1499 struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
1501 
1502  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1503  PatternRewriter &rewriter) const override {
1504  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1505  if (!operand)
1506  return failure();
1507  unsigned int operandNumber = operand->getOperandNumber();
1508  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1509  Location loc = insertOp.getLoc();
1510 
1511  // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
1512  if (insertOp.getNumIndices() == 0)
1513  return failure();
1514 
1515  // Rewrite vector.insert with 1d dest to vector.insertelement.
1516  if (insertOp.getDestVectorType().getRank() == 1) {
1517  if (insertOp.hasDynamicPosition())
1518  // TODO: Dinamic position not supported yet.
1519  return failure();
1520 
1521  assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1522  int64_t pos = insertOp.getStaticPosition()[0];
1523  rewriter.setInsertionPoint(insertOp);
1524  rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
1525  insertOp, insertOp.getSource(), insertOp.getDest(),
1526  rewriter.create<arith::ConstantIndexOp>(loc, pos));
1527  return success();
1528  }
1529 
1530  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1531  // There is no distribution, this is a broadcast. Simply move the insert
1532  // out of the warp op.
1533  SmallVector<size_t> newRetIndices;
1534  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1535  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1536  {insertOp.getSourceType(), insertOp.getDestVectorType()},
1537  newRetIndices);
1538  rewriter.setInsertionPointAfter(newWarpOp);
1539  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1540  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1541  Value newResult = rewriter.create<vector::InsertOp>(
1542  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1543  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1544  newResult);
1545  return success();
1546  }
1547 
1548  // Find the distributed dimension. There should be exactly one.
1549  auto distrDestType =
1550  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1551  auto yieldedType = cast<VectorType>(operand->get().getType());
1552  int64_t distrDestDim = -1;
1553  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1554  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1555  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1556  // support distributing multiple dimensions in the future.
1557  assert(distrDestDim == -1 && "found multiple distributed dims");
1558  distrDestDim = i;
1559  }
1560  }
1561  assert(distrDestDim != -1 && "could not find distributed dimension");
1562 
1563  // Compute the distributed source vector type.
1564  VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1565  SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
1566  srcVecType.getShape().end());
1567  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1568  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1569  // insert a smaller vector<3xf32>.
1570  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1571  // case, one lane will insert the source vector<96xf32>. The other
1572  // lanes will not do anything.
1573  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1574  if (distrSrcDim >= 0)
1575  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1576  auto distrSrcType =
1577  VectorType::get(distrSrcShape, distrDestType.getElementType());
1578 
1579  // Yield source and dest vectors from warp op.
1580  SmallVector<size_t> newRetIndices;
1581  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1582  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1583  {distrSrcType, distrDestType}, newRetIndices);
1584  rewriter.setInsertionPointAfter(newWarpOp);
1585  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1586  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1587 
1588  // Insert into the distributed vector.
1589  Value newResult;
1590  if (distrSrcDim >= 0) {
1591  // Every lane inserts a small piece.
1592  newResult = rewriter.create<vector::InsertOp>(
1593  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1594  } else {
1595  // One lane inserts the entire source vector.
1596  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1597  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1598  SmallVector<int64_t> newPos = getAsIntegers(pos);
1599  // tid of inserting lane: pos / elementsPerLane
1600  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1601  loc, newPos[distrDestDim] / elementsPerLane);
1602  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1603  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1604  // Insert position: pos % elementsPerLane
1605  newPos[distrDestDim] %= elementsPerLane;
1606  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1607  Value newInsert = builder.create<vector::InsertOp>(
1608  loc, distributedSrc, distributedDest, newPos);
1609  builder.create<scf::YieldOp>(loc, newInsert);
1610  };
1611  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1612  builder.create<scf::YieldOp>(loc, distributedDest);
1613  };
1614  newResult = rewriter
1615  .create<scf::IfOp>(loc, isInsertingLane,
1616  /*thenBuilder=*/insertingBuilder,
1617  /*elseBuilder=*/nonInsertingBuilder)
1618  .getResult(0);
1619  }
1620 
1621  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1622  return success();
1623  }
1624 };
1625 
1626 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1627 /// the scf.ForOp is the last operation in the region so that it doesn't change
1628 /// the order of execution. This creates a new scf.for region after the
1629 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1630 /// WarpExecuteOnLane0Op region. Example:
1631 /// ```
1632 /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1633 /// ...
1634 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1635 /// -> (vector<128xf32>) {
1636 /// ...
1637 /// scf.yield %r : vector<128xf32>
1638 /// }
1639 /// vector.yield %v1 : vector<128xf32>
1640 /// }
1641 /// ```
1642 /// To:
1643 /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1644 /// ...
1645 /// vector.yield %v : vector<128xf32>
1646 /// }
1647 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
1648 /// -> (vector<4xf32>) {
1649 /// %iw = vector.warp_execute_on_lane_0(%laneid)
1650 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1651 /// ^bb0(%arg: vector<128xf32>):
1652 /// ...
1653 /// vector.yield %ir : vector<128xf32>
1654 /// }
1655 /// scf.yield %iw : vector<4xf32>
1656 /// }
1657 /// ```
1658 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1659 
1660  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1661  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1662  distributionMapFn(std::move(fn)) {}
1664  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1665  PatternRewriter &rewriter) const override {
1666  auto yield = cast<vector::YieldOp>(
1667  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1668  // Only pick up forOp if it is the last op in the region.
1669  Operation *lastNode = yield->getPrevNode();
1670  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1671  if (!forOp)
1672  return failure();
1673  // Collect Values that come from the warp op but are outside the forOp.
1674  // Those Value needs to be returned by the original warpOp and passed to the
1675  // new op.
1676  llvm::SmallSetVector<Value, 32> escapingValues;
1677  SmallVector<Type> inputTypes;
1678  SmallVector<Type> distTypes;
1680  forOp.getBodyRegion(), [&](OpOperand *operand) {
1681  Operation *parent = operand->get().getParentRegion()->getParentOp();
1682  if (warpOp->isAncestor(parent)) {
1683  if (!escapingValues.insert(operand->get()))
1684  return;
1685  Type distType = operand->get().getType();
1686  if (auto vecType = dyn_cast<VectorType>(distType)) {
1687  AffineMap map = distributionMapFn(operand->get());
1688  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1689  }
1690  inputTypes.push_back(operand->get().getType());
1691  distTypes.push_back(distType);
1692  }
1693  });
1694 
1695  SmallVector<size_t> newRetIndices;
1696  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1697  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1698  newRetIndices);
1699  yield = cast<vector::YieldOp>(
1700  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1701 
1702  SmallVector<Value> newOperands;
1703  SmallVector<unsigned> resultIdx;
1704  // Collect all the outputs coming from the forOp.
1705  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1706  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1707  continue;
1708  auto forResult = cast<OpResult>(yieldOperand.get());
1709  newOperands.push_back(
1710  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1711  yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1712  resultIdx.push_back(yieldOperand.getOperandNumber());
1713  }
1714 
1715  OpBuilder::InsertionGuard g(rewriter);
1716  rewriter.setInsertionPointAfter(newWarpOp);
1717 
1718  // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719  // inside.
1720  auto newForOp = rewriter.create<scf::ForOp>(
1721  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1722  forOp.getStep(), newOperands);
1723  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
1724 
1725  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1726  newForOp.getRegionIterArgs().end());
1727  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1728  forOp.getResultTypes().end());
1729  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1730  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1731  warpInput.push_back(newWarpOp.getResult(retIdx));
1732  argIndexMapping[escapingValues[i]] = warpInputType.size();
1733  warpInputType.push_back(inputTypes[i]);
1734  }
1735  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1736  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1737  newWarpOp.getWarpSize(), warpInput, warpInputType);
1738 
1739  SmallVector<Value> argMapping;
1740  argMapping.push_back(newForOp.getInductionVar());
1741  for (Value args : innerWarp.getBody()->getArguments()) {
1742  argMapping.push_back(args);
1743  }
1744  argMapping.resize(forOp.getBody()->getNumArguments());
1745  SmallVector<Value> yieldOperands;
1746  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1747  yieldOperands.push_back(operand);
1748  rewriter.eraseOp(forOp.getBody()->getTerminator());
1749  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1750  rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
1751  rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1752  rewriter.setInsertionPointAfter(innerWarp);
1753  if (!innerWarp.getResults().empty())
1754  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1755  rewriter.eraseOp(forOp);
1756  // Replace the warpOp result coming from the original ForOp.
1757  for (const auto &res : llvm::enumerate(resultIdx)) {
1758  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1759  newForOp.getResult(res.index()));
1760  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1761  }
1762  newForOp.walk([&](Operation *op) {
1763  for (OpOperand &operand : op->getOpOperands()) {
1764  auto it = argIndexMapping.find(operand.get());
1765  if (it == argIndexMapping.end())
1766  continue;
1767  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1768  }
1769  });
1770 
1771  // Finally, hoist out any now uniform code from the inner warp op.
1772  mlir::vector::moveScalarUniformCode(innerWarp);
1773  return success();
1774  }
1775 
1776 private:
1777  DistributionMapFn distributionMapFn;
1778 };
1779 
1780 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781 /// The vector is reduced in parallel. Currently limited to vector size matching
1782 /// the warpOp size. E.g.:
1783 /// ```
1784 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1785 /// %0 = "some_def"() : () -> (vector<32xf32>)
1786 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1787 /// vector_ext.yield %1 : f32
1788 /// }
1789 /// ```
1790 /// is lowered to:
1791 /// ```
1792 /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1793 /// %1 = "some_def"() : () -> (vector<32xf32>)
1794 /// vector_ext.yield %1 : vector<32xf32>
1795 /// }
1796 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1797 /// %r = ("warp.reduction %a")
1798 /// ```
1799 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
1800  WarpOpReduction(MLIRContext *context,
1801  DistributedReductionFn distributedReductionFn,
1802  PatternBenefit benefit = 1)
1803  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
1804  distributedReductionFn(std::move(distributedReductionFn)) {}
1805 
1806  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1807  PatternRewriter &rewriter) const override {
1808  OpOperand *yieldOperand =
1809  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1810  if (!yieldOperand)
1811  return failure();
1812 
1813  auto reductionOp =
1814  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1815  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1816  // Only rank 1 vectors supported.
1817  if (vectorType.getRank() != 1)
1818  return rewriter.notifyMatchFailure(
1819  warpOp, "Only rank 1 reductions can be distributed.");
1820  // Only warp_size-sized vectors supported.
1821  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1822  return rewriter.notifyMatchFailure(
1823  warpOp, "Reduction vector dimension must match was size.");
1824  if (!reductionOp.getType().isIntOrFloat())
1825  return rewriter.notifyMatchFailure(
1826  warpOp, "Reduction distribution currently only supports floats and "
1827  "integer types.");
1828 
1829  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1830  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1831  unsigned operandIndex = yieldOperand->getOperandNumber();
1832  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1833  SmallVector<Type> retTypes = {
1834  VectorType::get({numElements}, reductionOp.getType())};
1835  if (reductionOp.getAcc()) {
1836  yieldValues.push_back(reductionOp.getAcc());
1837  retTypes.push_back(reductionOp.getAcc().getType());
1838  }
1839  SmallVector<size_t> newRetIndices;
1840  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1841  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1842  rewriter.setInsertionPointAfter(newWarpOp);
1843 
1844  // Obtain data to reduce for a single lane.
1845  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1846  // Distribute and reduce across threads.
1847  Value fullReduce =
1848  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1849  reductionOp.getKind(), newWarpOp.getWarpSize());
1850  if (reductionOp.getAcc()) {
1851  fullReduce = vector::makeArithReduction(
1852  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1853  newWarpOp.getResult(newRetIndices[1]));
1854  }
1855  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1856  return success();
1857  }
1858 
1859 private:
1860  DistributedReductionFn distributedReductionFn;
1861 };
1862 
1863 } // namespace
1864 
1866  RewritePatternSet &patterns,
1868  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1869 }
1870 
1871 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1872  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1873  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
1874  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1875  maxNumElementsToExtract, benefit);
1876 }
1877 
1878 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1879  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1880  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1881  PatternBenefit readBenefit) {
1882  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1883  patterns
1884  .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1885  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1886  WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1887  patterns.getContext(), benefit);
1888  patterns.add<WarpOpExtractElement>(patterns.getContext(),
1889  warpShuffleFromIdxFn, benefit);
1890  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1891  benefit);
1892 }
1893 
1894 void mlir::vector::populateDistributeReduction(
1895  RewritePatternSet &patterns,
1896  const DistributedReductionFn &distributedReductionFn,
1897  PatternBenefit benefit) {
1898  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1899  benefit);
1900 }
1901 
1902 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1903  Block *body = warpOp.getBody();
1904 
1905  // Keep track of the ops we want to hoist.
1906  llvm::SmallSetVector<Operation *, 8> opsToMove;
1907 
1908  // Helper to check if a value is or will be defined outside of the region.
1909  auto isDefinedOutsideOfBody = [&](Value value) {
1910  auto *definingOp = value.getDefiningOp();
1911  return (definingOp && opsToMove.count(definingOp)) ||
1912  warpOp.isDefinedOutsideOfRegion(value);
1913  };
1914 
1915  // Do not use walk here, as we do not want to go into nested regions and hoist
1916  // operations from there.
1917  for (auto &op : body->without_terminator()) {
1918  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1919  return isa<VectorType>(result.getType());
1920  });
1921  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1922  opsToMove.insert(&op);
1923  }
1924 
1925  // Move all the ops marked as uniform outside of the region.
1926  for (Operation *op : opsToMove)
1927  op->moveBefore(warpOp);
1928 }
static llvm::ManagedStatic< PassManagerOptions > options
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
static AffineMap calculateImplicitMap(VectorType sequentialType, VectorType distributedType)
Currently the distribution map is implicit based on the vector shape.
static OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, const std::function< bool(Operation *)> &fn)
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes)
Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, llvm::SmallVector< size_t > &indices)
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
Base type for affine expression.
Definition: AffineExpr.h:69
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:883
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:926
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:399
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:391
unsigned getNumResults() const
Definition: AffineMap.cpp:386
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:540
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:329
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
Operation & front()
Definition: Block.h:150
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:206
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:379
MLIRContext * getContext() const
Definition: Builders.h:55
IndexType getIndexType()
Definition: Builders.cpp:71
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:555
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
BlockListType & getBlocks()
Definition: Region.h:45
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:214
Type getType() const
Return the type of this value.
Definition: Value.h:125
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1391
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1138
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:2236
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
Definition: VectorOps.cpp:284
std::function< AffineMap(Value)> DistributionMapFn
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:363
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Definition: AffineMap.cpp:683
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Definition: RegionUtils.cpp:36
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:609
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.