//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include <utility>

using namespace mlir;
using namespace mlir::vector;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  assert(map.getNumResults() <= 1 &&
         "only support distribution along one dimension for now.");
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = sequentialVal.getType().dyn_cast<VectorType>();
    distributedVectorType = distributedVal.getType().dyn_cast<VectorType>();
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

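  /// Build the offset along the distributed dimension `index`: each lane owns
  /// a contiguous chunk of `distributedSize` elements, so lane `laneId` starts
  /// at `laneId * distributedSize`. For instance (an illustrative value, not
  /// from the original source), with a distributed dimension of size 2, lane
  /// 5 addresses the sequential buffer at offset 10.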
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineDimExpr(0, b.getContext());
    return b.createOrFold<AffineApplyOp>(loc, tid * distributedSize,
                                         ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // Scalar case can directly use memref.store.
    if (!val.getType().isa<VectorType>())
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // Vector case must use vector::TransferWriteOp which will later lower to
    // vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>
  ///
  /// When the requested `type` is the sequential type, the load is not
  /// distributed, to account for the broadcast semantics of the
  /// `vector.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
  ///   vector.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // Scalar case can directly use memref.load.
    if (!type.isa<VectorType>())
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vector atm.
    // Vector case must use vector::TransferReadOp which will later lower to
    // vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, type.cast<VectorType>(), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region, don't create a new output.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValue = yieldOperand.get();
    Operation *definedOp = yieldValue.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
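///
/// As an illustrative sketch (not taken from the original source), assuming a
/// `warpAllocationFn` that allocates a `memref<32xf32>` and a warp size of
/// 32, a warp op distributing a vector<32xf32> into vector<1xf32> becomes
/// roughly:
/// ```
/// %c0 = arith.constant 0 : index
/// %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
/// %buffer = ...  // provided by options.warpAllocationFn
/// scf.if %is_lane_0 {
///   ...  // former body of the warp op
///   vector.transfer_write %v, %buffer[%c0] : vector<32xf32>, memref<32xf32>
/// }
/// // synchronization provided by options.warpSyncronizationFn
/// %r = vector.transfer_read %buffer[%laneid], %cst
///     : memref<32xf32>, vector<1xf32>
/// ```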
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store distributed vector into buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load sequential vector from buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3. Insert sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4. Move body of warpOp to ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5. Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store yielded value into buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load distributed value from buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6. Insert sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7. Delete terminator and add empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, ValueRange{writeOp.getVector()}, TypeRange{targetType},
      newRetIndices);
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  return newWriteOp;
}

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a number of dimensions equal
/// to the rank of the original type and should be a projection where the
/// results are the distributed dimensions. The number of results should be
/// equal to the number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64xf32> distributed with the map
/// (d0, d1, d2) -> (d1) and a warp size of 16 distributes the second
/// dimension (associated to d1) and returns vector<16x2x64xf32>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  if (map.getNumResults() != 1)
    return VectorType();
  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                   originalType.getShape().end());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0)
      return VectorType();
    targetShape[position] = targetShape[position] / warpSize;
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, there is nothing to distribute; bail out and
    //    leave it to `tryExtractOp`, which moves it into a separate warp op.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    //    the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
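    // Each index that maps to a distributed dimension is rewritten as
    // `index + laneId * elementsPerLane`. Illustrative numbers (not from the
    // original source): a warp of 32 writing vector<64xf32> gives each lane
    // 2 elements, so lane %id writes its vector<2xf32> at `%c0 + 2 * %id`.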
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vectors of 1 element for now to avoid serializing large
    // vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isMemoryEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op) &&
             op->getNumResults() == 1;
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    Attribute newAttr = DenseElementsAttr::get(
        warpOp.getResult(operandIndex).getType().cast<ShapedType>(),
        scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    // Don't duplicate transfer_read ops when distributing.
    if (!read.getResult().hasOneUse())
      return failure();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = read.getResult().getType().cast<VectorType>();
    auto distributedType = distributedVal.getType().cast<VectorType>();
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this to WarpExecuteOnLane0Op canonicalization.
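/// Illustrative example (not from the original source): if the body yields
/// %a, %a, %b and the result matching %b is dead, the new op only yields %a;
/// both live results are replaced by the single deduplicated result.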
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplication occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);
    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region (i.e. it is defined
// outside the region or is a warp op argument), forward it directly: it does
// not need to go through the region.
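// Illustrative example (not from the original source):
//   %r = vector.warp_execute_on_lane_0(%laneid)
//       args(%v : vector<4xf32>) -> (vector<4xf32>) {
//   ^bb0(%arg : vector<128xf32>):
//     vector.yield %arg : vector<128xf32>
//   }
// %r can be replaced directly by the operand %v.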
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    return success();
  }
};

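/// Sink out a broadcast feeding into a warp op yield: the broadcast source is
/// yielded from the warp op instead and the broadcast is recreated after it
/// with the distributed result type. Illustrative sketch (not from the
/// original source, warp size 32):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   %b = vector.broadcast %s : f32 to vector<32xf32>
///   vector.yield %b : vector<32xf32>
/// }
/// ```
/// becomes a warp op yielding the f32 %s, followed by
/// `vector.broadcast %s : f32 to vector<1xf32>` outside the region.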
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
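/// Illustrative sketch for the distributed case (not from the original
/// source, warp size 32):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid) -> (vector<3xf32>) {
///   %e = vector.extract %v[2] : vector<4x96xf32>
///   vector.yield %e : vector<96xf32>
/// }
/// ```
/// becomes a warp op yielding %v distributed as vector<4x3xf32>, followed by
/// a `vector.extract` of position [2] from the distributed vector outside the
/// region.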
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // "vector.extract %v[] : vector<f32>" is an invalid op.
    assert(extractSrcType.getRank() > 0 &&
           "vector.extract does not support rank 0 sources");

    // "vector.extract %v[] : vector<...xf32>" can be canonicalized to %v.
    if (extractOp.getPosition().empty())
      return failure();

    // Rewrite vector.extract with 1d source to vector.extractelement.
    if (extractSrcType.getRank() == 1) {
      assert(extractOp.getPosition().size() == 1 && "expected 1 index");
      int64_t pos = extractOp.getPosition()[0].cast<IntegerAttr>().getInt();
      rewriter.setInsertionPoint(extractOp);
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
          extractOp, extractOp.getVector(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() ==
        operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        warpOp.getResult(operandNumber).getType().cast<VectorType>();
    auto yieldedType = operand->get().getType().cast<VectorType>();
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield source vector from warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
                                             extractSrcType.getShape().end());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getPosition().size()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out vector.extractelement. Extracts from 0-D or
/// single-element vectors don't need to be distributed and can just be
/// propagated outside of the region; 1-D sources are distributed, with one
/// lane extracting and shuffling the value to all other lanes.
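/// Illustrative sketch for the 1-D case (not from the original source,
/// assuming a warp size of 32 and a `warpShuffleFromIdxFn` that lowers to a
/// gpu.shuffle from the given lane):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
///   %e = vector.extractelement %v[%pos : index] : vector<64xf32>
///   vector.yield %e : f32
/// }
/// ```
/// becomes a warp op yielding %v distributed as vector<2xf32>; lane
/// `%pos / 2` extracts element `%pos % 2` locally and the scalar is shuffled
/// to all other lanes.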
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                       PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::ExtractElementOp>(op);
    });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extractelement src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield source vector from warp op.
    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0d extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      if (extractSrcType.getRank() == 1) {
        newExtract = rewriter.create<vector::ExtractElementOp>(
            loc, distributedVec,
            rewriter.create<arith::ConstantIndexOp>(loc, 0));
      } else {
        newExtract =
            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
      }
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // 1d extract: Distribute the source vector. One lane extracts and
    // shuffles the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = rewriter.create<AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane), extractOp.getPosition());
    // Extract at position: pos % elementsPerLane
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<AffineApplyOp>(loc, sym0 % elementsPerLane,
                                         extractOp.getPosition())
                  .getResult();
    Value extracted =
        rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

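/// Pattern to move out vector.insertelement. When the destination vector is
/// not distributed (broadcast case), the insert is simply moved out of the
/// region; otherwise only the lane that owns the insertion position performs
/// the insert, guarded by an scf.if. Illustrative sketch (not from the
/// original source, warp size 32):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid) -> (vector<2xf32>) {
///   %w = vector.insertelement %s, %v[%pos : index] : vector<64xf32>
///   vector.yield %w : vector<64xf32>
/// }
/// ```
/// becomes a warp op yielding %v, %s and %pos; lane `%pos / 2` inserts %s at
/// position `%pos % 2` into its vector<2xf32>, all other lanes keep their
/// piece unchanged.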
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::InsertElementOp>(op);
    });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        warpOp.getResult(operandNumber).getType().cast<VectorType>();
    bool hasPos = static_cast<bool>(insertOp.getPosition());

    // Yield destination vector, source scalar and position from warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    if (hasPos) {
      additionalResults.push_back(insertOp.getPosition());
      additionalResultTypes.push_back(insertOp.getPosition().getType());
    }
    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
    rewriter.setInsertionPointAfter(newWarpOp);

    if (vecType == distrType) {
      // Broadcast: Simply move the vector.insertelement op out.
      Value newInsert = rewriter.create<vector::InsertElementOp>(
          loc, newSource, distributedVec, newPos);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting thread: pos / elementsPerLane
    Value insertingLane = rewriter.create<AffineApplyOp>(
        loc, sym0.floorDiv(elementsPerLane), newPos);
    // Insert position: pos % elementsPerLane
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<AffineApplyOp>(loc, sym0 % elementsPerLane, newPos)
                  .getResult();
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertElementOp>(
                      loc, newSource, distributedVec, pos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

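/// Pattern to move out vector.insert with a rank > 1 destination, the n-D
/// counterpart of WarpOpInsertElement: when the inserted slice spans the
/// distributed dimension, every lane inserts its piece; otherwise a single
/// lane inserts the whole source vector under an scf.if. Illustrative sketch
/// (not from the original source, warp size 32):
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid) -> (vector<128x3xf32>) {
///   %w = vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
///   vector.yield %w : vector<128x96xf32>
/// }
/// ```
/// distributes along the trailing dimension: each lane inserts its
/// vector<3xf32> piece at position [2] of its vector<128x3xf32>.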
struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
    Location loc = insertOp.getLoc();

    // "vector.insert %s, %v[] : ..." can be canonicalized to %s.
    if (insertOp.getPosition().empty())
      return failure();

    // Rewrite vector.insert with 1d dest to vector.insertelement.
    if (insertOp.getDestVectorType().getRank() == 1) {
      assert(insertOp.getPosition().size() == 1 && "expected 1 index");
      int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt();
      rewriter.setInsertionPoint(insertOp);
      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
          insertOp, insertOp.getSource(), insertOp.getDest(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    if (warpOp.getResult(operandNumber).getType() ==
        operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the insert
      // out of the warp op.
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
          {insertOp.getSourceType(), insertOp.getDestVectorType()},
          newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
      Value newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newResult);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distrDestType =
        warpOp.getResult(operandNumber).getType().cast<VectorType>();
    auto yieldedType = operand->get().getType().cast<VectorType>();
    int64_t distrDestDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distrDestDim == -1 && "found multiple distributed dims");
        distrDestDim = i;
      }
    }
    assert(distrDestDim != -1 && "could not find distributed dimension");

    // Compute the distributed source vector type.
    VectorType srcVecType = insertOp.getSourceType().cast<VectorType>();
    SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
                                       srcVecType.getShape().end());
    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
    //         insert a smaller vector<3xf32>.
    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
    //         case, one lane will insert the source vector<96xf32>. The other
    //         lanes will not do anything.
    int64_t distrSrcDim = distrDestDim - insertOp.getPosition().size();
    if (distrSrcDim >= 0)
      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
    auto distrSrcType =
        VectorType::get(distrSrcShape, distrDestType.getElementType());

    // Yield source and dest vectors from warp op.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
        {distrSrcType, distrDestType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);

    // Insert into the distributed vector.
    Value newResult;
    if (distrSrcDim >= 0) {
      // Every lane inserts a small piece.
      newResult = rewriter.create<vector::InsertOp>(
          loc, distributedSrc, distributedDest, insertOp.getPosition());
    } else {
      // One lane inserts the entire source vector.
      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
      SmallVector<int64_t> newPos = llvm::to_vector(
          llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
            return attr.cast<IntegerAttr>().getInt();
          }));
      // tid of inserting lane: pos / elementsPerLane
      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
          loc, newPos[distrDestDim] / elementsPerLane);
      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
      // Insert position: pos % elementsPerLane
      newPos[distrDestDim] %= elementsPerLane;
      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
        Value newInsert = builder.create<vector::InsertOp>(
            loc, distributedSrc, distributedDest, newPos);
        builder.create<scf::YieldOp>(loc, newInsert);
      };
      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
        builder.create<scf::YieldOp>(loc, distributedDest);
      };
      newResult = rewriter
                      .create<scf::IfOp>(loc, isInsertingLane,
                                         /*thenBuilder=*/insertingBuilder,
                                         /*elseBuilder=*/nonInsertingBuilder)
                      .getResult(0);
    }

    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///   ^bb0(%arg: vector<128xf32>):
///     ...
///     vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
1311 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1312 
1313  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1314  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1315  distributionMapFn(std::move(fn)) {}
1317  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1318  PatternRewriter &rewriter) const override {
1319  auto yield = cast<vector::YieldOp>(
1320  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1321  // Only pick up forOp if it is the last op in the region.
1322  Operation *lastNode = yield->getPrevNode();
1323  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1324  if (!forOp)
1325  return failure();
1326  // Collect Values that come from the warp op but are outside the forOp.
1327  // Those Value needs to be returned by the original warpOp and passed to the
1328  // new op.
1329  llvm::SmallSetVector<Value, 32> escapingValues;
1330  SmallVector<Type> inputTypes;
1331  SmallVector<Type> distTypes;
1332  mlir::visitUsedValuesDefinedAbove(
1333  forOp.getBodyRegion(), [&](OpOperand *operand) {
1334  Operation *parent = operand->get().getParentRegion()->getParentOp();
1335  if (warpOp->isAncestor(parent)) {
1336  if (!escapingValues.insert(operand->get()))
1337  return;
1338  Type distType = operand->get().getType();
1339  if (auto vecType = distType.dyn_cast<VectorType>()) {
1340  AffineMap map = distributionMapFn(operand->get());
1341  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1342  }
1343  inputTypes.push_back(operand->get().getType());
1344  distTypes.push_back(distType);
1345  }
1346  });
1347 
1348  SmallVector<size_t> newRetIndices;
1349  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1350  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1351  newRetIndices);
1352  yield = cast<vector::YieldOp>(
1353  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1354 
1355  SmallVector<Value> newOperands;
1356  SmallVector<unsigned> resultIdx;
1357  // Collect all the outputs coming from the forOp.
1358  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1359  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1360  continue;
1361  auto forResult = yieldOperand.get().cast<OpResult>();
1362  newOperands.push_back(
1363  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1364  yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
1365  resultIdx.push_back(yieldOperand.getOperandNumber());
1366  }
1367 
1368  OpBuilder::InsertionGuard g(rewriter);
1369  rewriter.setInsertionPointAfter(newWarpOp);
1370 
1371  // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1372  // inside.
1373  auto newForOp = rewriter.create<scf::ForOp>(
1374  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1375  forOp.getStep(), newOperands);
1376  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
1377 
1378  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1379  newForOp.getRegionIterArgs().end());
1380  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1381  forOp.getResultTypes().end());
1382  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1383  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1384  warpInput.push_back(newWarpOp.getResult(retIdx));
1385  argIndexMapping[escapingValues[i]] = warpInputType.size();
1386  warpInputType.push_back(inputTypes[i]);
1387  }
1388  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1389  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1390  newWarpOp.getWarpSize(), warpInput, warpInputType);
1391 
1392  SmallVector<Value> argMapping;
1393  argMapping.push_back(newForOp.getInductionVar());
1394  for (Value args : innerWarp.getBody()->getArguments()) {
1395  argMapping.push_back(args);
1396  }
1397  argMapping.resize(forOp.getBody()->getNumArguments());
1398  SmallVector<Value> yieldOperands;
1399  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1400  yieldOperands.push_back(operand);
1401  rewriter.eraseOp(forOp.getBody()->getTerminator());
1402  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1403  rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
1404  rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1405  rewriter.setInsertionPointAfter(innerWarp);
1406  if (!innerWarp.getResults().empty())
1407  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1408  rewriter.eraseOp(forOp);
1409  // Replace the warpOp result coming from the original ForOp.
1410  for (const auto &res : llvm::enumerate(resultIdx)) {
1411  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1412  newForOp.getResult(res.index()));
1413  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); // +3 skips the lb, ub and step operands.
1414  }
1415  newForOp.walk([&](Operation *op) {
1416  for (OpOperand &operand : op->getOpOperands()) {
1417  auto it = argIndexMapping.find(operand.get());
1418  if (it == argIndexMapping.end())
1419  continue;
1420  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1421  }
1422  });
1423  return success();
1424  }
1425 
1426 private:
1427  DistributionMapFn distributionMapFn;
1428 };
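// A minimal sketch of a DistributionMapFn that could be handed to the pattern
// above, assuming distribution along the innermost dimension of a vector of
// rank >= 1; the function name is hypothetical and not part of this file.
static AffineMap distributeAlongInnermostDim(Value val) {
  auto vecType = val.getType().cast<VectorType>();
  // One result: the innermost dimension is the distributed one, mirroring the
  // single-result convention of calculateImplicitMap.
  return AffineMap::get(
      vecType.getRank(), /*symbolCount=*/0,
      {getAffineDimExpr(vecType.getRank() - 1, val.getContext())},
      val.getContext());
}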
1429 
1430 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1431 /// The vector is reduced in parallel. Currently limited to vector sizes that
1432 /// are a multiple of the warp size. E.g.:
1433 /// ```
1434 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1435 /// %0 = "some_def"() : () -> (vector<32xf32>)
1436 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1437 /// vector_ext.yield %1 : f32
1438 /// }
1439 /// ```
1440 /// is lowered to:
1441 /// ```
1442 /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1443 /// %1 = "some_def"() : () -> (vector<32xf32>)
1444 /// vector_ext.yield %1 : vector<32xf32>
1445 /// }
1446 /// %a = vector.extract %0[0] : vector<1xf32>
1447 /// %r = ("warp.reduction %a")
1448 /// ```
1449 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
1450  WarpOpReduction(MLIRContext *context,
1451  DistributedReductionFn distributedReductionFn,
1452  PatternBenefit benefit = 1)
1453  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
1454  distributedReductionFn(std::move(distributedReductionFn)) {}
1455 
1456  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1457  PatternRewriter &rewriter) const override {
1458  OpOperand *yieldOperand = getWarpResult(
1459  warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
1460  if (!yieldOperand)
1461  return failure();
1462 
1463  auto reductionOp =
1464  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1465  auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
1466  // Only rank 1 vectors supported.
1467  if (vectorType.getRank() != 1)
1468  return rewriter.notifyMatchFailure(
1469  warpOp, "Only rank 1 reductions can be distributed.");
1470  // Only vector sizes that are a multiple of the warp size are supported.
1471  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1472  return rewriter.notifyMatchFailure(
1473  warpOp, "Reduction vector dimension must be a multiple of the warp size.");
1474  if (!reductionOp.getType().isIntOrFloat())
1475  return rewriter.notifyMatchFailure(
1476  warpOp, "Reduction distribution currently only supports float and "
1477  "integer types.");
1478 
1479  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
1480  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
1481  unsigned operandIndex = yieldOperand->getOperandNumber();
1482  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1483  SmallVector<Type> retTypes = {
1484  VectorType::get({numElements}, reductionOp.getType())};
1485  if (reductionOp.getAcc()) {
1486  yieldValues.push_back(reductionOp.getAcc());
1487  retTypes.push_back(reductionOp.getAcc().getType());
1488  }
1489  SmallVector<size_t> newRetIndices;
1490  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1491  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1492  rewriter.setInsertionPointAfter(newWarpOp);
1493 
1494  // Obtain data to reduce for a single lane.
1495  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1496  // Distribute and reduce across threads.
1497  Value fullReduce =
1498  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1499  reductionOp.getKind(), newWarpOp.getWarpSize());
1500  if (reductionOp.getAcc()) {
1501  fullReduce = vector::makeArithReduction(
1502  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1503  newWarpOp.getResult(newRetIndices[1]));
1504  }
1505  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1506  return success();
1507  }
1508 
1509 private:
1510  DistributedReductionFn distributedReductionFn;
1511 };
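// A minimal sketch of a DistributedReductionFn that could be plugged into
// WarpOpReduction, assuming the gpu dialect is available and included; the
// function name and the butterfly-shuffle strategy are illustrative, not part
// of this file.
static Value butterflyWarpReduction(Location loc, OpBuilder &builder,
                                    Value input, CombiningKind kind,
                                    uint32_t size) {
  // Reduce the per-lane slice to a scalar first.
  Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
  // Combine partial results across lanes with xor (butterfly) shuffles.
  for (uint64_t i = 1; i < size; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, laneVal, i,
                                                 /*width=*/size,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
  }
  return laneVal;
}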
1512 
1513 } // namespace
1514 
1515 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1516  RewritePatternSet &patterns,
1517  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
1518  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1519 }
1520 
1521 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1522  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1523  PatternBenefit benefit) {
1524  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1525  benefit);
1526 }
1527 
1528 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1529  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1530  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
1531  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
1532  WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
1533  WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
1534  patterns.getContext(), benefit);
1535  patterns.add<WarpOpExtractElement>(patterns.getContext(),
1536  warpShuffleFromIdxFn, benefit);
1537  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1538  benefit);
1539 }
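// A minimal sketch of a WarpShuffleFromIdxFn for the WarpOpExtractElement
// pattern registered above, assuming 32-bit element types and an available
// gpu dialect; it broadcasts `val` from lane `srcIdx` to every lane. The
// function name is hypothetical.
static Value shuffleFromIdx(Location loc, OpBuilder &builder, Value val,
                            Value srcIdx, int64_t warpSz) {
  Type i32Type = builder.getIntegerType(32);
  // gpu.shuffle expects i32 operands for the source lane and warp width.
  Value srcIdxI32 = builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
  Value warpSzI32 = builder.create<arith::ConstantOp>(
      loc, builder.getIntegerAttr(i32Type, warpSz));
  return builder
      .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
                              gpu::ShuffleMode::IDX)
      .getShuffleResult();
}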
1540 
1541 void mlir::vector::populateDistributeReduction(
1542  RewritePatternSet &patterns,
1543  const DistributedReductionFn &distributedReductionFn,
1544  PatternBenefit benefit) {
1545  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1546  benefit);
1547 }
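// A hypothetical end-to-end driver showing how the populate entry points in
// this file are typically combined; it assumes
// "mlir/Transforms/GreedyPatternRewriteDriver.h" is included, and the
// callbacks are user-supplied hooks like the sketches above.
static void runVectorDistribution(Operation *root,
                                  const DistributionMapFn &mapFn,
                                  const WarpShuffleFromIdxFn &shuffleFn,
                                  const DistributedReductionFn &reduceFn) {
  RewritePatternSet patterns(root->getContext());
  populateDistributeTransferWriteOpPatterns(patterns, mapFn);
  populatePropagateWarpVectorDistributionPatterns(patterns, mapFn, shuffleFn);
  populateDistributeReduction(patterns, reduceFn);
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}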
1548 
1549 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1550  Block *body = warpOp.getBody();
1551 
1552  // Keep track of the ops we want to hoist.
1553  llvm::SmallSetVector<Operation *, 8> opsToMove;
1554 
1555  // Helper to check if a value is or will be defined outside of the region.
1556  auto isDefinedOutsideOfBody = [&](Value value) {
1557  auto *definingOp = value.getDefiningOp();
1558  return (definingOp && opsToMove.count(definingOp)) ||
1559  warpOp.isDefinedOutsideOfRegion(value);
1560  };
1561 
1562  // Do not use walk here, as we do not want to go into nested regions and hoist
1563  // operations from there.
1564  for (auto &op : body->without_terminator()) {
1565  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1566  return result.getType().isa<VectorType>();
1567  });
1568  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1569  opsToMove.insert(&op);
1570  }
1571 
1572  // Move all the ops marked as uniform outside of the region.
1573  for (Operation *op : opsToMove)
1574  op->moveBefore(warpOp);
1575 }
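// A hypothetical usage sketch: hoist uniform scalar code out of every warp
// region in a payload op before applying the distribution patterns.
static void hoistAllUniformCode(Operation *root) {
  root->walk(
      [](WarpExecuteOnLane0Op warpOp) { moveScalarUniformCode(warpOp); });
}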