//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>
#include <utility>

using namespace mlir;
using namespace mlir::vector;

/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension with an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
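/// For example (a sketch, not tied to a particular allocation function): a
/// captured vector<32xf32> transits through a memref<32xf32> buffer; each
/// lane stores its distributed vector<1xf32> slice at offset `laneId` outside
/// the sequential region, and lane 0 reads back the full vector<32xf32>
/// inside it.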
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }
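
  // For instance (a sketch): when distributing vector<32xf32> across a warp of
  // 32 lanes into vector<1xf32>, `distributedSize` is 1 and the offset for
  // that dimension folds to `laneId`; for vector<64xf32> -> vector<2xf32> it
  // would be `laneId * 2`.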

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When the requested `type` is the sequential (i.e. non-distributed) type,
  /// the load is not distributed, to account for the broadcast semantics of
  /// the `vector.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
  ///   vector.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // All other cases must be vectors for now.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = cast<AffineDimExpr>(dimExpr).getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.modifyOpInPlace(
      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value is already yielded from the region, don't create a new
      // output for it.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}
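
// For illustration (a sketch): appending a new vector<32xf32> yield,
// distributed as vector<1xf32>, to
//   %0 = vector.warp_execute_on_lane_0(%id)[32] -> (vector<1xf32>) {...}
// rebuilds it as
//   %0:2 = vector.warp_execute_on_lane_0(%id)[32]
//       -> (vector<1xf32>, vector<1xf32>) {...}
// and `indices` records the result position assigned to each appended value.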

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValue = yieldOperand.get();
    Operation *definedOp = yieldValue.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
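///
/// For illustration, a rough sketch of the resulting structure (the exact IR
/// depends on the allocation and synchronization callbacks in `options`):
/// ```
/// %c0 = arith.constant 0 : index
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// %buffer = ...                             // options.warpAllocationFn
/// vector.transfer_write %dist, %buffer[...] // transit distributed values in
/// ...                                       // options.warpSyncronizationFn
/// scf.if %is_lane_0 {
///   %seq = vector.transfer_read %buffer[...] // read sequential values back
///   ...                                      // original warp op body
/// }
/// ```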
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the ifOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType,
                                            VectorType maybeMaskType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp;
  if (maybeMaskType) {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()},
        TypeRange{targetType, maybeMaskType}, newRetIndices);
  } else {
    newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, ValueRange{{writeOp.getVector()}},
        TypeRange{targetType}, newRetIndices);
  }
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  if (maybeMaskType)
    newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1]));
  return newWriteOp;
}

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a number of dimensions equal
/// to the original type's rank and should be a projection where the results
/// are the distributed dimensions. The number of results should be equal to
/// the number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1)
/// and a warp size of 16 distributes the second dimension (associated with
/// d1) and returns vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  SmallVector<int64_t> targetShape(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0) {
      if (warpSize % targetShape[position] != 0) {
        return VectorType();
      }
      warpSize /= targetShape[position];
      targetShape[position] = 1;
      continue;
    }
    targetShape[position] = targetShape[position] / warpSize;
    warpSize = 1;
    break;
  }
  if (warpSize != 1) {
    return VectorType();
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}
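
// Worked example of the non-dividing branch above (a sketch): distributing
// vector<4x64xf32> with map (d0, d1) -> (d0, d1) over a warp of 32 first
// assigns 4 lanes to d0 (shape 4 -> 1, remaining warp size 8), then splits
// d1 across the remaining 8 lanes (shape 64 -> 8), giving vector<1x8xf32>.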

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
/// will not be extracted into a separate warp op (the limit should be less
/// than the warp size).
///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      unsigned maxNumElementsToExtract, PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        distributionMapFn(std::move(fn)),
        maxNumElementsToExtract(maxNumElementsToExtract) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. A 0-D write cannot be distributed; bail (it may still be handled by
    // `tryExtractOp` below).
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 2.5. Compute the distributed type for the new mask.
    VectorType maskType;
    if (writeOp.getMask()) {
      // TODO: Distribution of masked writes with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted written vector. Currently the details
      // of which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!writeOp.getPermutationMap().isMinorIdentity())
        return failure();
      maskType =
          getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize());
    }

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();

    // Delinearize the lane id based on the way threads are divided across the
    // vector. To get the number of threads per vector dimension, divide the
    // sequential size by the distributed size along each dim.
    rewriter.setInsertionPoint(newWriteOp);
    SmallVector<OpFoldResult> delinearizedIdSizes;
    for (auto [seqSize, distSize] :
         llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) {
      assert(seqSize % distSize == 0 && "Invalid distributed vector shape");
      delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize));
    }
    SmallVector<Value> delinearized;
    if (map.getNumResults() > 1) {
      delinearized = rewriter
                         .create<mlir::affine::AffineDelinearizeIndexOp>(
                             newWarpOp.getLoc(), newWarpOp.getLaneid(),
                             delinearizedIdSizes)
                         .getResults();
    } else {
      // If there is only one map result, we can elide the delinearization
      // op and use the lane id directly.
      delinearized.append(targetType.getRank(), newWarpOp.getLaneid());
    }

    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      Value laneId = delinearized[vectorPos];
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x...> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    if (vecType.getNumElements() > maxNumElementsToExtract) {
      return rewriter.notifyMatchFailure(
          warpOp,
          llvm::formatv(
              "writes more elements ({0}) than allowed to extract ({1})",
              vecType.getNumElements(), maxNumElementsToExtract));
    }

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(),
                     llvm::IsaPred<vector::TransferWriteOp, vector::YieldOp>))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Operation *lastNode = yield->getPrevNode();
    auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
    if (!writeOp)
      return failure();

    Value maybeMask = writeOp.getMask();
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 (maybeMask && maybeMask == value) ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    // Masked writes not supported for extraction.
    if (writeOp.getMask())
      return failure();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
  unsigned maxNumElementsToExtract = 1;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();

    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>);
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Delinearize the given `laneId` into multiple dimensions, where each
/// dimension's size is determined by `originalShape` and `distributedShape`
/// together. This function expects the total number of threads needed for
/// distribution to be equal to `warpSize`. Returns true and updates
/// `delinearizedIds` if so.
bool delinearizeLaneId(OpBuilder &builder, Location loc,
                       ArrayRef<int64_t> originalShape,
                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
  // If the original shape and the distributed shape are the same, we don't
  // distribute at all--every thread handles the whole vector. In that case,
  // we should not rely on lane IDs later. So just return an empty lane ID
  // vector.
  if (originalShape == distributedShape) {
    delinearizedIds.clear();
    return true;
  }

  SmallVector<int64_t> sizes;
  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
    if (large % small != 0)
      return false;
    sizes.push_back(large / small);
  }
  if (std::accumulate(sizes.begin(), sizes.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return false;

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  int64_t usedThreads = 1;

  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  delinearizedIds.assign(sizes.size(), zero);

  for (int i = sizes.size() - 1; i >= 0; --i) {
    usedThreads *= sizes[i];
    if (usedThreads == warpSize) {
      // We've used up all available threads. No need to perform modulo
      // anymore, and we can stop the calculation for further dimensions.
      delinearizedIds[i] = laneId;
      break;
    }
    delinearizedIds[i] =
        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
    laneId = affine::makeComposedAffineApply(
        builder, loc, s0.floorDiv(usedThreads), {laneId});
  }
  return true;
}
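
// Worked example (a sketch): originalShape [4, 64], distributedShape [1, 8],
// and warpSize 32 give sizes [4, 8]. The loop first computes
// delinearizedIds[1] = laneId % 8 and divides laneId by 8; the next iteration
// uses up all 32 threads, so delinearizedIds[0] = laneId / 8.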

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Try to find a distributable yielded read. Note that this pattern can
    // still fail at the end after distribution, in which case this might have
    // missed another distributable read.
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      // Don't duplicate transfer_read ops when distributing.
      return isa<vector::TransferReadOp>(op) && op->hasOneUse();
    });
    if (!operand)
      return rewriter.notifyMatchFailure(
          warpOp, "warp result is not a vector.transfer_read op");
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();

    // Source must be defined outside of the region.
    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
      return rewriter.notifyMatchFailure(
          read, "source must be defined outside of the region");

    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds)) {
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    }
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    // Distribute indices and the mask (if present).
    OpBuilder::InsertionGuard g(rewriter);
    SmallVector<Value> additionalResults(indices.begin(), indices.end());
    SmallVector<Type> additionalResultTypes(indices.size(),
                                            rewriter.getIndexType());
    additionalResults.push_back(read.getPadding());
    additionalResultTypes.push_back(read.getPadding().getType());

    bool hasMask = false;
    if (read.getMask()) {
      hasMask = true;
      // TODO: Distribution of masked reads with non-trivial permutation maps
      // requires the distribution of the mask to elementwise match the
      // distribution of the permuted read vector. Currently the details of
      // which lane is responsible for which element are captured strictly
      // by shape information on the warp op, and thus require materializing
      // the permutation in IR.
      if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity())
        return rewriter.notifyMatchFailure(
            read, "non-trivial permutation maps not supported");
      VectorType maskType =
          getDistributedType(read.getMaskType(), map, warpOp.getWarpSize());
      additionalResults.push_back(read.getMask());
      additionalResultTypes.push_back(maskType);
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    distributedVal = newWarpOp.getResult(operandIndex);

    // Distributed indices were appended first.
    SmallVector<Value> newIndices;
    for (int64_t i = 0, e = indices.size(); i < e; ++i)
      newIndices.push_back(newWarpOp.getResult(newRetIndices[i]));

    rewriter.setInsertionPointAfter(newWarpOp);
    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it));
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      newIndices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {newIndices[indexPos], delinearizedIds[vectorPos]});
    }

    // Distributed padding value was appended right after the indices.
    Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]);
    // Distributed mask value was added at the end (if the op has a mask).
    Value newMask =
        hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                : Value();
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
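/// For instance (a sketch): in
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%id)[32]
///     -> (vector<1xf32>, vector<1xf32>) {
///   ...
///   vector.yield %a, %a : vector<32xf32>, vector<32xf32>
/// }
/// ```
/// both results map to the same yielded value, so the op is rebuilt with a
/// single result; results with no uses are dropped entirely.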
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);

    // Simplify the new warp op after dropping dead results.
    newWarpOp.getBody()->walk([&](Operation *op) {
      if (isOpTriviallyDead(op))
        rewriter.eraseOp(op);
    });

    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

/// If an operand is yielded out of the region without being modified, forward
/// it directly to the matching warp op result; it does not need to transit
/// through the region.
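/// For instance (a sketch):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id)[32] args(%v : f32) -> (f32) {
/// ^bb0(%arg : f32):
///   vector.yield %arg : f32
/// }
/// ```
/// Uses of %0 can be replaced with %v directly.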
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};
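/// Sink out a broadcast feeding into a warp op yield, e.g. (a sketch):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id)[32] -> (vector<1xf32>) {
///   %b = vector.broadcast %s : f32 to vector<32xf32>
///   vector.yield %b : vector<32xf32>
/// }
/// ```
/// becomes a yield of %s from the warp op, followed by a vector.broadcast of
/// the new result to vector<1xf32> outside the region.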
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // its own broadcast.
    // For that we simply check that the broadcast we want to build makes sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is mostly a
/// no-op for warp distribution; only the shapes need to be adjusted.
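/// For example (a sketch): a yielded
///   %r = vector.shape_cast %v : vector<32x1xf32> to vector<32xf32>
/// distributed over 32 lanes yields %v as vector<1x1xf32> from the warp op
/// and re-applies the cast outside of it:
///   %r = vector.shape_cast %v' : vector<1x1xf32> to vector<1xf32>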
struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>);
    if (!operand)
      return failure();

    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xi1>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   vector.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c1, %c0
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand =
        getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>);
    if (!yieldOperand)
      return failure();

    auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>();

    // Early exit if any values needed for calculating the new mask indices
    // are defined inside the warp op.
    if (!llvm::all_of(mask->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    Location loc = mask.getLoc();
    unsigned operandIndex = yieldOperand->getOperandNumber();

    auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType());
    VectorType seqType = mask.getVectorType();
    ArrayRef<int64_t> seqShape = seqType.getShape();
    ArrayRef<int64_t> distShape = distType.getShape();

    rewriter.setInsertionPointAfter(warpOp);

    // Delinearize the lane ID for constructing the distributed mask sizes.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, loc, seqShape, distShape,
                           warpOp.getWarpSize(), warpOp.getLaneid(),
                           delinearizedIds))
      return rewriter.notifyMatchFailure(
          mask, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty());

    // Notify the rewriter that the warp op is changing (see the comment on
    // the WarpOpTransferRead pattern).
    rewriter.startOpModification(warpOp);

    AffineExpr s0, s1;
    bindSymbols(rewriter.getContext(), s0, s1);
    SmallVector<Value> newOperands;
    for (int i = 0, e = distShape.size(); i < e; ++i) {
      // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to
      // find the distance from the largest mask index owned by this lane to
      // the original mask size. `vector.create_mask` implicitly clamps mask
      // operands to the range [0, mask_vector_size[i]], so the mask sizes are
      // always in that range.
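      // E.g. (a sketch): with seqShape [32], distShape [1], and a dynamic
      // mask size of 17, lane i computes 17 - i, which create_mask clamps
      // into [0, 1]: lanes 0..16 get an active element, the remaining none.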
      Value maskDimIdx = affine::makeComposedAffineApply(
          rewriter, loc, s1 - s0 * distShape[i],
          {delinearizedIds[i], mask.getOperand(i)});
      newOperands.push_back(maskDimIdx);
    }

    auto newMask =
        rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
    rewriter.finalizeOpModification(warpOp);
    return success();
  }
};

/// Pattern to move out vector.extract of single-element vectors. Those don't
/// need to be distributed and can just be propagated outside of the region.
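/// E.g. (a sketch): a yielded
///   %e = vector.extract %v[0] : vector<32xf32> from vector<2x32xf32>
/// whose result is distributed as vector<1xf32> instead yields %v as
/// vector<2x1xf32> from the warp op and extracts outside of it:
///   %e = vector.extract %v'[0] : vector<1xf32> from vector<2x1xf32>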
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand =
        getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
    assert(extractSrcType.getRank() > 0 &&
           "vector.extract does not support rank 0 sources");

    // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
    // canonicalized to %v.
    if (extractOp.getNumIndices() == 0)
      return failure();

    // Rewrite vector.extract with 1d source to vector.extractelement.
    if (extractSrcType.getRank() == 1) {
      if (extractOp.hasDynamicPosition())
        // TODO: Dynamic position not supported yet.
        return failure();

      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
      int64_t pos = extractOp.getStaticPosition()[0];
      rewriter.setInsertionPoint(extractOp);
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
          extractOp, extractOp.getVector(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extractelement of 0-D or single-element vectors.
/// Those don't need to be distributed and can just be propagated outside of
/// the region; multi-element 1-D sources are distributed, with one lane
/// extracting and shuffling the value to the others.
1318 struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1319  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1320  PatternBenefit b = 1)
1321  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1322  warpShuffleFromIdxFn(std::move(fn)) {}
1323  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1324  PatternRewriter &rewriter) const override {
1325  OpOperand *operand =
1326  getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1327  if (!operand)
1328  return failure();
1329  unsigned int operandNumber = operand->getOperandNumber();
1330  auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1331  VectorType extractSrcType = extractOp.getSourceVectorType();
1332  // TODO: Supported shuffle types should be parameterizable, similar to
1333  // `WarpShuffleFromIdxFn`.
1334  if (!extractSrcType.getElementType().isF32() &&
1335  !extractSrcType.getElementType().isInteger(32))
1336  return rewriter.notifyMatchFailure(
1337  extractOp, "only f32/i32 element types are supported");
1338  bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
1339  Type elType = extractSrcType.getElementType();
1340  VectorType distributedVecType;
1341  if (!is0dOrVec1Extract) {
1342  assert(extractSrcType.getRank() == 1 &&
1343  "expected that extractelement src rank is 0 or 1");
1344  if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
1345  return failure();
1346  int64_t elementsPerLane =
1347  extractSrcType.getShape()[0] / warpOp.getWarpSize();
1348  distributedVecType = VectorType::get({elementsPerLane}, elType);
1349  } else {
1350  distributedVecType = extractSrcType;
1351  }
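    // Illustrative example (warp size 32): a vector<64xf32> source gives
    // elementsPerLane = 2, i.e. each lane holds a vector<2xf32>.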
1352  // Yield source vector and position (if present) from warp op.
1353  SmallVector<Value> additionalResults{extractOp.getVector()};
1354  SmallVector<Type> additionalResultTypes{distributedVecType};
1355  if (static_cast<bool>(extractOp.getPosition())) {
1356  additionalResults.push_back(extractOp.getPosition());
1357  additionalResultTypes.push_back(extractOp.getPosition().getType());
1358  }
1359  Location loc = extractOp.getLoc();
1360  SmallVector<size_t> newRetIndices;
1361  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1362  rewriter, warpOp, additionalResults, additionalResultTypes,
1363  newRetIndices);
1364  rewriter.setInsertionPointAfter(newWarpOp);
1365  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1366 
1367  // 0-d or single-element extract: the new warp op broadcasts the source
1368  // vector to all lanes; all lanes extract the scalar.
1369  if (is0dOrVec1Extract) {
1370  Value newExtract;
1371  if (extractSrcType.getRank() == 1) {
1372  newExtract = rewriter.create<vector::ExtractElementOp>(
1373  loc, distributedVec,
1374  rewriter.create<arith::ConstantIndexOp>(loc, 0));
1375 
1376  } else {
1377  newExtract =
1378  rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
1379  }
1380  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1381  newExtract);
1382  return success();
1383  }
1384 
1385  // 1d extract: Distribute the source vector. One lane extracts and shuffles
1386  // the value to all other lanes.
1387  int64_t elementsPerLane = distributedVecType.getShape()[0];
1388  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1389  // tid of extracting thread: pos / elementsPerLane
1390  Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
1391  loc, sym0.ceilDiv(elementsPerLane),
1392  newWarpOp->getResult(newRetIndices[1]));
1393  // Extract at position: pos % elementsPerLane
1394  Value pos =
1395  elementsPerLane == 1
1396  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1397  : rewriter
1398  .create<affine::AffineApplyOp>(
1399  loc, sym0 % elementsPerLane,
1400  newWarpOp->getResult(newRetIndices[1]))
1401  .getResult();
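    // Illustrative example (elementsPerLane = 2): pos = 4 broadcasts from
    // lane 2 with in-lane position 4 % 2 = 0.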
1402  Value extracted =
1403  rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
1404 
1405  // Shuffle the extracted value to all lanes.
1406  Value shuffled = warpShuffleFromIdxFn(
1407  loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
1408  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
1409  return success();
1410  }
1411 
1412 private:
1413  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1414 };
1415 
1416 struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1417  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1418 
1419  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1420  PatternRewriter &rewriter) const override {
1421  OpOperand *operand =
1422  getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1423  if (!operand)
1424  return failure();
1425  unsigned int operandNumber = operand->getOperandNumber();
1426  auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1427  VectorType vecType = insertOp.getDestVectorType();
1428  VectorType distrType =
1429  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1430  bool hasPos = static_cast<bool>(insertOp.getPosition());
1431 
1432  // Yield destination vector, source scalar and position from warp op.
1433  SmallVector<Value> additionalResults{insertOp.getDest(),
1434  insertOp.getSource()};
1435  SmallVector<Type> additionalResultTypes{distrType,
1436  insertOp.getSource().getType()};
1437  if (hasPos) {
1438  additionalResults.push_back(insertOp.getPosition());
1439  additionalResultTypes.push_back(insertOp.getPosition().getType());
1440  }
1441  Location loc = insertOp.getLoc();
1442  SmallVector<size_t> newRetIndices;
1443  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1444  rewriter, warpOp, additionalResults, additionalResultTypes,
1445  newRetIndices);
1446  rewriter.setInsertionPointAfter(newWarpOp);
1447  Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
1448  Value newSource = newWarpOp->getResult(newRetIndices[1]);
1449  Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
1450  rewriter.setInsertionPointAfter(newWarpOp);
1451 
1452  if (vecType == distrType) {
1453  // Broadcast: Simply move the vector.insertelement op out.
1454  Value newInsert = rewriter.create<vector::InsertElementOp>(
1455  loc, newSource, distributedVec, newPos);
1456  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1457  newInsert);
1458  return success();
1459  }
1460 
1461  // This is a distribution. Only one lane should insert.
1462  int64_t elementsPerLane = distrType.getShape()[0];
1463  AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
1464  // tid of inserting lane: pos / elementsPerLane
1465  Value insertingLane = rewriter.create<affine::AffineApplyOp>(
1466  loc, sym0.ceilDiv(elementsPerLane), newPos);
1467  // Insert position: pos % elementsPerLane
1468  Value pos =
1469  elementsPerLane == 1
1470  ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1471  : rewriter
1472  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1473  newPos)
1474  .getResult();
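    // Illustrative example (elementsPerLane = 3): pos = 6 selects inserting
    // lane 2 and in-lane position 6 % 3 = 0.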
1475  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1476  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1477  Value newResult =
1478  rewriter
1479  .create<scf::IfOp>(
1480  loc, isInsertingLane,
1481  /*thenBuilder=*/
1482  [&](OpBuilder &builder, Location loc) {
1483  Value newInsert = builder.create<vector::InsertElementOp>(
1484  loc, newSource, distributedVec, pos);
1485  builder.create<scf::YieldOp>(loc, newInsert);
1486  },
1487  /*elseBuilder=*/
1488  [&](OpBuilder &builder, Location loc) {
1489  builder.create<scf::YieldOp>(loc, distributedVec);
1490  })
1491  .getResult(0);
1492  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1493  return success();
1494  }
1495 };
1496 
1497 struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
1498  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1499 
1500  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1501  PatternRewriter &rewriter) const override {
1502  OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
1503  if (!operand)
1504  return failure();
1505  unsigned int operandNumber = operand->getOperandNumber();
1506  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1507  Location loc = insertOp.getLoc();
1508 
1509  // "vector.insert %src, %dst[] : ..." can be canonicalized to %src.
1510  if (insertOp.getNumIndices() == 0)
1511  return failure();
1512 
1513  // Rewrite vector.insert with 1d dest to vector.insertelement.
1514  if (insertOp.getDestVectorType().getRank() == 1) {
1515  if (insertOp.hasDynamicPosition())
1516  // TODO: Dynamic position not supported yet.
1517  return failure();
1518 
1519  assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1520  int64_t pos = insertOp.getStaticPosition()[0];
1521  rewriter.setInsertionPoint(insertOp);
1522  rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
1523  insertOp, insertOp.getSource(), insertOp.getDest(),
1524  rewriter.create<arith::ConstantIndexOp>(loc, pos));
1525  return success();
1526  }
1527 
1528  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1529  // There is no distribution, this is a broadcast. Simply move the insert
1530  // out of the warp op.
1531  SmallVector<size_t> newRetIndices;
1532  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1533  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1534  {insertOp.getSourceType(), insertOp.getDestVectorType()},
1535  newRetIndices);
1536  rewriter.setInsertionPointAfter(newWarpOp);
1537  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1538  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1539  Value newResult = rewriter.create<vector::InsertOp>(
1540  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1541  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1542  newResult);
1543  return success();
1544  }
1545 
1546  // Find the distributed dimension. There should be exactly one.
1547  auto distrDestType =
1548  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1549  auto yieldedType = cast<VectorType>(operand->get().getType());
1550  int64_t distrDestDim = -1;
1551  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1552  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1553  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1554  // support distributing multiple dimensions in the future.
1555  assert(distrDestDim == -1 && "found multiple distributed dims");
1556  distrDestDim = i;
1557  }
1558  }
1559  assert(distrDestDim != -1 && "could not find distributed dimension");
1560 
1561  // Compute the distributed source vector type.
1562  VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1563  SmallVector<int64_t> distrSrcShape(srcVecType.getShape());
1564  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1565  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1566  // insert a smaller vector<3xf32>.
1567  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1568  // case, one lane will insert the source vector<96xf32>. The other
1569  // lanes will not do anything.
1570  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1571  if (distrSrcDim >= 0)
1572  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1573  auto distrSrcType =
1574  VectorType::get(distrSrcShape, distrDestType.getElementType());
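    // Continuing the example above with warp size 32: in case 1 each lane's
    // piece of the source is vector<3xf32>; in case 2 distrSrcShape is left
    // unchanged (vector<96xf32>).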
1575 
1576  // Yield source and dest vectors from warp op.
1577  SmallVector<size_t> newRetIndices;
1578  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1579  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1580  {distrSrcType, distrDestType}, newRetIndices);
1581  rewriter.setInsertionPointAfter(newWarpOp);
1582  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1583  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1584 
1585  // Insert into the distributed vector.
1586  Value newResult;
1587  if (distrSrcDim >= 0) {
1588  // Every lane inserts a small piece.
1589  newResult = rewriter.create<vector::InsertOp>(
1590  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1591  } else {
1592  // One lane inserts the entire source vector.
1593  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1594  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1595  SmallVector<int64_t> newPos = getAsIntegers(pos);
1596  // tid of inserting lane: pos / elementsPerLane
1597  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1598  loc, newPos[distrDestDim] / elementsPerLane);
1599  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1600  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1601  // Insert position: pos % elementsPerLane
1602  newPos[distrDestDim] %= elementsPerLane;
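      // Illustrative example (elementsPerLane = 4): inserting at position 9
      // selects lane 9 / 4 = 2 and in-lane position 9 % 4 = 1.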
1603  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1604  Value newInsert = builder.create<vector::InsertOp>(
1605  loc, distributedSrc, distributedDest, newPos);
1606  builder.create<scf::YieldOp>(loc, newInsert);
1607  };
1608  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1609  builder.create<scf::YieldOp>(loc, distributedDest);
1610  };
1611  newResult = rewriter
1612  .create<scf::IfOp>(loc, isInsertingLane,
1613  /*thenBuilder=*/insertingBuilder,
1614  /*elseBuilder=*/nonInsertingBuilder)
1615  .getResult(0);
1616  }
1617 
1618  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1619  return success();
1620  }
1621 };
1622 
1623 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1624 /// the scf.ForOp is the last operation in the region so that it doesn't change
1625 /// the order of execution. This creates a new scf.for region after the
1626 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1627 /// WarpExecuteOnLane0Op region. Example:
1628 /// ```
1629 /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1630 /// ...
1631 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1632 /// -> (vector<128xf32>) {
1633 /// ...
1634 /// scf.yield %r : vector<128xf32>
1635 /// }
1636 /// vector.yield %v1 : vector<128xf32>
1637 /// }
1638 /// ```
1639 /// To:
1640 /// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
1641 /// ...
1642 /// vector.yield %v : vector<128xf32>
1643 /// }
1644 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1645 /// -> (vector<4xf32>) {
1646 /// %iw = vector.warp_execute_on_lane_0(%laneid)
1647 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1648 /// ^bb0(%arg: vector<128xf32>):
1649 /// ...
1650 /// vector.yield %ir : vector<128xf32>
1651 /// }
1652 /// scf.yield %iw : vector<4xf32>
1653 /// }
1654 /// ```
1655 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1656 
1657  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1658  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1659  distributionMapFn(std::move(fn)) {}
1661  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1662  PatternRewriter &rewriter) const override {
1663  auto yield = cast<vector::YieldOp>(
1664  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1665  // Only pick up forOp if it is the last op in the region.
1666  Operation *lastNode = yield->getPrevNode();
1667  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1668  if (!forOp)
1669  return failure();
1670  // Collect values defined in the warp op region that are used inside the
1671  // forOp. Those values need to be returned by the original warpOp and
1672  // passed on to the new op.
1673  llvm::SmallSetVector<Value, 32> escapingValues;
1674  SmallVector<Type> inputTypes;
1675  SmallVector<Type> distTypes;
1676  mlir::visitUsedValuesDefinedAbove(
1677  forOp.getBodyRegion(), [&](OpOperand *operand) {
1678  Operation *parent = operand->get().getParentRegion()->getParentOp();
1679  if (warpOp->isAncestor(parent)) {
1680  if (!escapingValues.insert(operand->get()))
1681  return;
1682  Type distType = operand->get().getType();
1683  if (auto vecType = dyn_cast<VectorType>(distType)) {
1684  AffineMap map = distributionMapFn(operand->get());
1685  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1686  }
1687  inputTypes.push_back(operand->get().getType());
1688  distTypes.push_back(distType);
1689  }
1690  });
1691 
1692  if (llvm::is_contained(distTypes, Type{}))
1693  return failure();
1694 
1695  SmallVector<size_t> newRetIndices;
1696  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1697  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1698  newRetIndices);
1699  yield = cast<vector::YieldOp>(
1700  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1701 
1702  SmallVector<Value> newOperands;
1703  SmallVector<unsigned> resultIdx;
1704  // Collect all the outputs coming from the forOp.
1705  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1706  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1707  continue;
1708  auto forResult = cast<OpResult>(yieldOperand.get());
1709  newOperands.push_back(
1710  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1711  yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1712  resultIdx.push_back(yieldOperand.getOperandNumber());
1713  }
1714 
1715  OpBuilder::InsertionGuard g(rewriter);
1716  rewriter.setInsertionPointAfter(newWarpOp);
1717 
1718  // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719  // inside.
1720  auto newForOp = rewriter.create<scf::ForOp>(
1721  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1722  forOp.getStep(), newOperands);
1723  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
1724 
1725  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1726  newForOp.getRegionIterArgs().end());
1727  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1728  forOp.getResultTypes().end());
1729  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1730  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1731  warpInput.push_back(newWarpOp.getResult(retIdx));
1732  argIndexMapping[escapingValues[i]] = warpInputType.size();
1733  warpInputType.push_back(inputTypes[i]);
1734  }
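    // At this point `warpInput`/`warpInputType` hold the loop iter_args
    // followed by the escaping values, and `argIndexMapping` records the
    // inner warp block argument index reserved for each escaping value.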
1735  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1736  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1737  newWarpOp.getWarpSize(), warpInput, warpInputType);
1738 
1739  SmallVector<Value> argMapping;
1740  argMapping.push_back(newForOp.getInductionVar());
1741  for (Value args : innerWarp.getBody()->getArguments()) {
1742  argMapping.push_back(args);
1743  }
1744  argMapping.resize(forOp.getBody()->getNumArguments());
1745  SmallVector<Value> yieldOperands;
1746  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1747  yieldOperands.push_back(operand);
1748  rewriter.eraseOp(forOp.getBody()->getTerminator());
1749  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1750  rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
1751  rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1752  rewriter.setInsertionPointAfter(innerWarp);
1753  if (!innerWarp.getResults().empty())
1754  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1755  rewriter.eraseOp(forOp);
1756  // Replace the warpOp result coming from the original ForOp.
1757  for (const auto &res : llvm::enumerate(resultIdx)) {
1758  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1759  newForOp.getResult(res.index()));
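      // Operand res.index() + 3 skips the lower bound, upper bound, and step
      // to reach the corresponding iter_args operand of the new ForOp.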
1760  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1761  }
1762  newForOp.walk([&](Operation *op) {
1763  for (OpOperand &operand : op->getOpOperands()) {
1764  auto it = argIndexMapping.find(operand.get());
1765  if (it == argIndexMapping.end())
1766  continue;
1767  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1768  }
1769  });
1770 
1771  // Finally, hoist out any now uniform code from the inner warp op.
1772  mlir::vector::moveScalarUniformCode(innerWarp);
1773  return success();
1774  }
1775 
1776 private:
1777  DistributionMapFn distributionMapFn;
1778 };
1779 
1780 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781 /// The vector is reduced in parallel. Currently limited to vector sizes that
1782 /// are a multiple of the warp size. E.g.:
1783 /// ```
1784 /// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1785 /// %0 = "some_def"() : () -> (vector<32xf32>)
1786 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1787 /// vector.yield %1 : f32
1788 /// }
1789 /// ```
1790 /// is lowered to:
1791 /// ```
1792 /// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1793 /// %1 = "some_def"() : () -> (vector<32xf32>)
1794 /// vector.yield %1 : vector<32xf32>
1795 /// }
1796 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1797 /// %r = ("warp.reduction %a")
1798 /// ```
1799 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
1800  WarpOpReduction(MLIRContext *context,
1801  DistributedReductionFn distributedReductionFn,
1802  PatternBenefit benefit = 1)
1803  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
1804  distributedReductionFn(std::move(distributedReductionFn)) {}
1805 
1806  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1807  PatternRewriter &rewriter) const override {
1808  OpOperand *yieldOperand =
1809  getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>);
1810  if (!yieldOperand)
1811  return failure();
1812 
1813  auto reductionOp =
1814  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1815  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1816  // Only rank 1 vectors supported.
1817  if (vectorType.getRank() != 1)
1818  return rewriter.notifyMatchFailure(
1819  warpOp, "Only rank 1 reductions can be distributed.");
1820  // Only vector sizes that are a multiple of the warp size are supported.
1821  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1822  return rewriter.notifyMatchFailure(
1823  warpOp, "Reduction vector dimension must be a multiple of the warp size.");
1824  if (!reductionOp.getType().isIntOrFloat())
1825  return rewriter.notifyMatchFailure(
1826  warpOp, "Reduction distribution currently only supports floats and "
1827  "integer types.");
1828 
1829  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
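    // Illustrative example (warp size 32): reducing a vector<64xf32> leaves
    // numElements = 2 elements to be yielded per lane.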
1830  // Yield the vector that will be reduced from the WarpExecuteOnLane0Op.
1831  unsigned operandIndex = yieldOperand->getOperandNumber();
1832  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1833  SmallVector<Type> retTypes = {
1834  VectorType::get({numElements}, reductionOp.getType())};
1835  if (reductionOp.getAcc()) {
1836  yieldValues.push_back(reductionOp.getAcc());
1837  retTypes.push_back(reductionOp.getAcc().getType());
1838  }
1839  SmallVector<size_t> newRetIndices;
1840  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1841  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1842  rewriter.setInsertionPointAfter(newWarpOp);
1843 
1844  // Obtain data to reduce for a single lane.
1845  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1846  // Distribute and reduce across threads.
1847  Value fullReduce =
1848  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1849  reductionOp.getKind(), newWarpOp.getWarpSize());
1850  if (reductionOp.getAcc()) {
1851  fullReduce = vector::makeArithReduction(
1852  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1853  newWarpOp.getResult(newRetIndices[1]));
1854  }
1855  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1856  return success();
1857  }
1858 
1859 private:
1860  DistributedReductionFn distributedReductionFn;
1861 };
1862 
1863 } // namespace
1864 
1865 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1866  RewritePatternSet &patterns,
1867  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
1868  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1869 }
1870 
1871 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1872  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1873  unsigned maxNumElementsToExtract, PatternBenefit benefit) {
1874  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1875  maxNumElementsToExtract, benefit);
1876 }
1877 
1878 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1879  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1880  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1881  PatternBenefit readBenefit) {
1882  patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1883  patterns
1884  .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1885  WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1886  WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1887  patterns.getContext(), benefit);
1888  patterns.add<WarpOpExtractElement>(patterns.getContext(),
1889  warpShuffleFromIdxFn, benefit);
1890  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1891  benefit);
1892 }
1893 
1894 void mlir::vector::populateDistributeReduction(
1895  RewritePatternSet &patterns,
1896  const DistributedReductionFn &distributedReductionFn,
1897  PatternBenefit benefit) {
1898  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1899  benefit);
1900 }
1901 
1902 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1903  Block *body = warpOp.getBody();
1904 
1905  // Keep track of the ops we want to hoist.
1906  llvm::SmallSetVector<Operation *, 8> opsToMove;
1907 
1908  // Helper to check if a value is or will be defined outside of the region.
1909  auto isDefinedOutsideOfBody = [&](Value value) {
1910  auto *definingOp = value.getDefiningOp();
1911  return (definingOp && opsToMove.count(definingOp)) ||
1912  warpOp.isDefinedOutsideOfRegion(value);
1913  };
1914 
1915  // Do not use walk here, as we do not want to go into nested regions and hoist
1916  // operations from there.
1917  for (auto &op : body->without_terminator()) {
1918  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1919  return isa<VectorType>(result.getType());
1920  });
1921  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1922  opsToMove.insert(&op);
1923  }
1924 
1925  // Move all the ops marked as uniform outside of the region.
1926  for (Operation *op : opsToMove)
1927  op->moveBefore(warpOp);
1928 }
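// Illustrative usage sketch (not part of this file; the function name and
// callbacks below are hypothetical): a client pass could wire these populate
// functions into a RewritePatternSet roughly as follows, supplying its own
// target-specific distribution map and warp-shuffle lowering (e.g. via
// gpu.shuffle):
//
//   void populateMyWarpDistribution(RewritePatternSet &patterns) {
//     // Assumed policy: distribute the innermost vector dimension.
//     DistributionMapFn mapFn = [](Value val) {
//       auto vecType = cast<VectorType>(val.getType());
//       MLIRContext *ctx = val.getContext();
//       int64_t rank = vecType.getRank();
//       return AffineMap::get(rank, /*symbolCount=*/0,
//                             getAffineDimExpr(rank - 1, ctx));
//     };
//     // Placeholder shuffle; a GPU backend would instead broadcast `val`
//     // from lane `srcIdx` (e.g. with gpu.shuffle idx).
//     WarpShuffleFromIdxFn shuffleFn = [](Location loc, OpBuilder &builder,
//                                         Value val, Value srcIdx,
//                                         int64_t warpSz) { return val; };
//     vector::populatePropagateWarpVectorDistributionPatterns(patterns, mapFn,
//                                                             shuffleFn);
//   }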