//===- VectorDistribute.cpp - patterns to do vector distribution ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include <numeric>
#include <utility>

using namespace mlir;
using namespace mlir::vector;

/// Currently the distribution map is implicit, based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) {
  SmallVector<AffineExpr> perm;
  perm.reserve(1);
  // Check which dimensions of the sequential type are different from the
  // dimensions of the distributed type to know the distributed dimensions.
  // Then associate each distributed dimension to an ID in order.
  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
  }
  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                            distributedType.getContext());
  return map;
}

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper {
  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                             Value laneId, Value zero)
      : sequentialVal(sequentialVal), distributedVal(distributedVal),
        laneId(laneId), zero(zero) {
    sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType());
    distributedVectorType = dyn_cast<VectorType>(distributedVal.getType());
    if (sequentialVectorType && distributedVectorType)
      distributionMap =
          calculateImplicitMap(sequentialVectorType, distributedVectorType);
  }

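  /// Compute the offset of the fragment owned by `laneId` along the
  /// distributed dimension `index`, i.e. `laneId * distributedSize`. For
  /// example (illustrative): distributing vector<32xf32> over 32 lanes as
  /// vector<1xf32> gives each lane the offset `laneId`.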
  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
    int64_t distributedSize = distributedVectorType.getDimSize(index);
    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
    return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize,
                                                 ArrayRef<Value>{laneId});
  }

  /// Create a store during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
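  ///
  /// For example (illustrative, assuming a 32-lane warp): a sequential
  /// vector<32xf32> distributed as vector<1xf32> transits through a
  /// memref<32xf32>; each lane stores its fragment at the offset
  /// `laneId * 1` computed by `buildDistributedOffset`.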
  Operation *buildStore(RewriterBase &b, Location loc, Value val,
                        Value buffer) {
    assert((val == distributedVal || val == sequentialVal) &&
           "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
    // The scalar case can directly use memref.store.
    if (!isa<VectorType>(val.getType()))
      return b.create<memref::StoreOp>(loc, val, buffer, zero);

    // The vector case must use vector::TransferWriteOp, which will later
    // lower to vector.store or memref.store depending on further lowerings.
    int64_t rank = sequentialVectorType.getRank();
    SmallVector<Value> indices(rank, zero);
    if (val == distributedVal) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferWriteOp>(
        loc, val, buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  /// Create a load during the process of distributing the
  /// `vector.warp_execute_on_lane_0` op.
  /// Vector distribution assumes the following convention regarding the
  /// temporary buffers that are created to transition values. This **must**
  /// be properly specified in the `options.warpAllocationFn`:
  ///   1. scalars of type T transit through a memref<1xT>.
  ///   2. vectors of type V<shapexT> transit through a memref<shapexT>.
  ///
  /// When broadcastMode is true, the load is not distributed to account for
  /// the broadcast semantics of the `vector.warp_execute_on_lane_0` op.
  ///
  /// Example:
  ///
  /// ```
  /// %r = vector.warp_execute_on_lane_0(...) -> (f32) {
  ///   vector.yield %cst : f32
  /// }
  /// // Both types are f32. The constant %cst is broadcasted to all lanes.
  /// ```
  /// This behavior is described in more detail in the documentation of the op.
  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {

    // The scalar case can directly use memref.load.
    if (!isa<VectorType>(type))
      return b.create<memref::LoadOp>(loc, buffer, zero);

    // Other cases must be vector atm.
    // The vector case must use vector::TransferReadOp, which will later lower
    // to vector.load or memref.load depending on further lowerings.
    assert((type == distributedVectorType || type == sequentialVectorType) &&
           "Must load either the preregistered distributed or the "
           "preregistered sequential type.");
    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
    if (type == distributedVectorType) {
      for (auto dimExpr : distributionMap.getResults()) {
        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
        indices[index] = buildDistributedOffset(b, loc, index);
      }
    }
    SmallVector<bool> inBounds(indices.size(), true);
    return b.create<vector::TransferReadOp>(
        loc, cast<VectorType>(type), buffer, indices,
        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
  }

  Value sequentialVal, distributedVal, laneId, zero;
  VectorType sequentialVectorType, distributedVectorType;
  AffineMap distributionMap;
};

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
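/// For instance (illustrative): given a warp op yielding %v : vector<32xf32>
/// and returning vector<1xf32>, calling this helper with
/// `newYieldedValues = {%w}` and `newReturnTypes = {vector<2xf32>}` produces
/// a new warp op with the same body whose terminator now yields %w and whose
/// single result has type vector<2xf32>.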
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns, for each new output, the result index at which it is
/// available on the new op.
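/// For example (illustrative): appending a value %b that the terminator
/// already yields does not create a new output; `indices` then records the
/// position of the existing result for %b.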
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
                                              yield.getOperands().end());
  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
    if (yieldValues.insert(std::get<0>(newRet))) {
      types.push_back(std::get<1>(newRet));
      indices.push_back(yieldValues.size() - 1);
    } else {
      // If the value already exits the region, don't create a new output.
      for (auto [idx, yieldOperand] :
           llvm::enumerate(yieldValues.getArrayRef())) {
        if (yieldOperand == std::get<0>(newRet)) {
          indices.push_back(idx);
          break;
        }
      }
    }
  }
  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues.getArrayRef(), types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isMemoryEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
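/// Typical use, as in the patterns below:
/// ```
/// OpOperand *operand = getWarpResult(
///     warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
/// ```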
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
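///
/// As a rough illustration (a sketch, not verbatim output; the buffer comes
/// from `options.warpAllocationFn` and the synchronization from
/// `options.warpSyncronizationFn`), a warp op yielding a vector<32xf32>
/// becomes:
/// ```
/// %cnd = arith.cmpi eq, %laneid, %c0 : index
/// // (alloc, write)-pairs for values captured from above
/// scf.if %cnd {
///   // sequential body, writing yielded values to temporary buffers
/// }
/// // synchronization, then distributed reads, e.g.
/// // %r = vector.transfer_read %buffer[%laneid], ... : vector<1xf32>
/// ```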
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfIfPattern(MLIRContext *context,
                       const WarpExecuteOnLane0LoweringOptions &options,
                       PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    assert(warpOp.getBodyRegion().hasOneBlock() &&
           "expected WarpOp with single block");
    Block *warpOpBody = &warpOp.getBodyRegion().front();
    Location loc = warpOp.getLoc();

    // Passed all checks. Start rewriting.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(warpOp);

    // Step 1: Create the scf.if op.
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value isLane0 = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                           /*withElseRegion=*/false);
    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

    // Step 2: Insert appropriate (alloc, write)-pairs before the scf.if and
    // reads within the scf.if to transit the values captured from above.
    SmallVector<Value> bbArgReplacements;
    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
      Value sequentialVal = warpOpBody->getArgument(it.index());
      Value distributedVal = it.value();
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());
      // Store the distributed vector into the buffer, before the ifOp.
      helper.buildStore(rewriter, loc, distributedVal, buffer);
      // Load the sequential vector from the buffer, inside the ifOp.
      rewriter.setInsertionPointToStart(ifOp.thenBlock());
      bbArgReplacements.push_back(
          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
    }

    // Step 3: Insert a sync after all the stores and before all the loads.
    if (!warpOp.getArgs().empty()) {
      rewriter.setInsertionPoint(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 4: Move the body of warpOp to the ifOp.
    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

    // Step 5: Insert appropriate writes within scf.if and reads after the
    // scf.if to transit the values returned by the op.
    // TODO: at this point, we can reuse the shared memory from previous
    // buffers.
    SmallVector<Value> replacements;
    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
    Location yieldLoc = yieldOp.getLoc();
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      Value sequentialVal = it.value();
      Value distributedVal = warpOp->getResult(it.index());
      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
                                        warpOp.getLaneid(), c0);

      // Create buffer before the ifOp.
      rewriter.setInsertionPoint(ifOp);
      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
                                              sequentialVal.getType());

      // Store the yielded value into the buffer, inside the ifOp, before the
      // terminator.
      rewriter.setInsertionPoint(yieldOp);
      helper.buildStore(rewriter, loc, sequentialVal, buffer);

      // Load the distributed value from the buffer, after the warpOp.
      rewriter.setInsertionPointAfter(ifOp);
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      replacements.push_back(
          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
    }

    // Step 6: Insert a sync after all the stores and before all the loads.
    if (!yieldOp.getOperands().empty()) {
      rewriter.setInsertionPointAfter(ifOp);
      options.warpSyncronizationFn(loc, rewriter, warpOp);
    }

    // Step 7: Delete the terminator and add an empty scf.yield.
    rewriter.eraseOp(yieldOp);
    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
    rewriter.create<scf::YieldOp>(yieldLoc);

    // Compute replacements for the WarpOp results.
    rewriter.replaceOp(warpOp, replacements);

    return success();
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp`, assumed to be nested under `warpOp`, into a new warp
/// execute op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  SmallVector<size_t> newRetIndices;
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
      TypeRange{targetType}, newRetIndices);
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
  return newWriteOp;
}

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension count equal to
/// the original type's rank and should be a projection where the results are
/// the distributed dimensions. The number of results should be equal to the
/// number of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1)
/// and a warp size of 16 has its second dimension (associated to d1)
/// distributed and returns vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) {
  if (map.getNumResults() != 1)
    return VectorType();
  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
                                   originalType.getShape().end());
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    unsigned position = map.getDimPosition(i);
    if (targetShape[position] % warpSize != 0)
      return VectorType();
    targetShape[position] = targetShape[position] / warpSize;
  }
  VectorType targetType =
      VectorType::get(targetShape, originalType.getElementType());
  return targetType;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, bail out; there is nothing to distribute.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distributed type.
    AffineMap map = distributionMapFn(writeOp.getVector());
    VectorType targetType =
        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
    if (!targetType)
      return failure();

    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 4. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
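    // Illustrative example: distributing a vector<32xf32> write to
    // vector<1xf32> over 32 lanes gives scale = 1, so each index becomes
    // `index + laneId * 1`, i.e. every lane writes its own element.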
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + scale * d1,
          {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vectors of 1 element for now, so as not to serialize
    // large vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with a mask are not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isMemoryEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out an elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op) &&
             op->getNumResults() == 1;
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = cast<VectorType>(operand.get().getType());
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!isa<VectorType>(operandType) &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex),
                                newOp->getResult(0));
    return success();
  }
};

/// Sink out a splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue());
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    auto newAttr = DenseElementsAttr::get(
        cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
    Location loc = warpOp.getLoc();
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
    return success();
  }
};

/// Delinearize the given `laneId` into multiple dimensions, where each
/// dimension's size is determined by `originalShape` and `distributedShape`
/// together. This function expects the total number of threads needed for
/// distribution to be equal to `warpSize`. Returns true and updates
/// `delinearizedIds` if so.
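///
/// For example (illustrative): originalShape = [4, 16] and distributedShape =
/// [1, 4] with warpSize = 16 gives sizes = [4, 4] and delinearizedIds =
/// (laneId floordiv 4, laneId mod 4).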
bool delinearizeLaneId(OpBuilder &builder, Location loc,
                       ArrayRef<int64_t> originalShape,
                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
  // If the original shape and the distributed shape are the same, we don't
  // distribute at all--every thread handles the whole vector. In that case,
  // we should not rely on lane IDs later, so just return an empty lane ID
  // vector.
  if (originalShape == distributedShape) {
    delinearizedIds.clear();
    return true;
  }

  SmallVector<int64_t> sizes;
  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
    if (large % small != 0)
      return false;
    sizes.push_back(large / small);
  }
  if (std::accumulate(sizes.begin(), sizes.end(), 1,
                      std::multiplies<int64_t>()) != warpSize)
    return false;

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  int64_t usedThreads = 1;

  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  delinearizedIds.assign(sizes.size(), zero);

  for (int i = sizes.size() - 1; i >= 0; --i) {
    usedThreads *= sizes[i];
    if (usedThreads == warpSize) {
      // We've used up all available threads. No need to perform the modulo
      // anymore, and we can stop the calculation for further dimensions.
      delinearizedIds[i] = laneId;
      break;
    }
    delinearizedIds[i] =
        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
    laneId = affine::makeComposedAffineApply(
        builder, loc, s0.floorDiv(usedThreads), {laneId});
  }
  return true;
}

/// Sink out a transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    // Don't duplicate transfer_read ops when distributing.
    if (!read.getResult().hasOneUse())
      return failure();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    auto sequentialType = cast<VectorType>(read.getResult().getType());
    auto distributedType = cast<VectorType>(distributedVal.getType());
    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);

    // Try to delinearize the lane ID to match the rank expected for
    // distribution.
    SmallVector<Value> delinearizedIds;
    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
                           distributedType.getShape(), warpOp.getWarpSize(),
                           warpOp.getLaneid(), delinearizedIds))
      return rewriter.notifyMatchFailure(
          read, "cannot delinearize lane ID for distribution");
    assert(!delinearizedIds.empty() || map.getNumResults() == 0);

    for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale = distributedType.getDimSize(vectorPos);
      indices[indexPos] = affine::makeComposedAffineApply(
          rewriter, read.getLoc(), d0 + scale * d1,
          {indices[indexPos], delinearizedIds[vectorPos]});
    }
    auto newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());

    // Check that the produced operation is legal.
    // The transfer op may be reading from values that are defined within
    // warpOp's body, which is illegal.
    // We do the check late because indices may be changed by
    // makeComposedAffineApply. This rewrite may remove dependencies from
    // warpOp's body.
    // E.g., warpOp {
    //   %idx = affine.apply...[%outsideDef]
    //   ... = transfer_read ...[%idx]
    // }
    // will be rewritten into:
    // warpOp {
    // }
    // %new_idx = affine.apply...[%outsideDef]
    // ... = transfer_read ...[%new_idx]
    if (!llvm::all_of(newRead->getOperands(), [&](Value value) {
          return warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    rewriter.replaceAllUsesWith(distributedVal, newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> newResultTypes;
    newResultTypes.reserve(warpOp->getNumResults());
    SmallVector<Value> newYieldValues;
    newYieldValues.reserve(warpOp->getNumResults());
    DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
    DenseMap<OpResult, int64_t> dedupResultPositionMap;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());

    // Some values may be yielded multiple times and correspond to multiple
    // results. Deduplicating occurs by taking each result with its matching
    // yielded value, and:
    //   1. recording the unique first position at which the value is yielded.
    //   2. recording for the result, the first position at which the dedup'ed
    //      value is yielded.
    //   3. skipping from the new result types / new yielded values any result
    //      that has no use or whose yielded value has already been seen.
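    // For example (illustrative): if the terminator yields %a, %b, %a and
    // only results #0 and #2 are used, the new op yields just %a, and both
    // results map to that single deduplicated position.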
    for (OpResult result : warpOp.getResults()) {
      Value yieldOperand = yield.getOperand(result.getResultNumber());
      auto it = dedupYieldOperandPositionMap.insert(
          std::make_pair(yieldOperand, newResultTypes.size()));
      dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
      if (result.use_empty() || !it.second)
        continue;
      newResultTypes.push_back(result.getType());
      newYieldValues.push_back(yieldOperand);
    }
    // No modification, exit early.
    if (yield.getNumOperands() == newYieldValues.size())
      return failure();
    // Move the body of the old warpOp to a new warpOp.
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, newYieldValues, newResultTypes);
    // Replace results of the old warpOp by the new, deduplicated results.
    SmallVector<Value> newValues;
    newValues.reserve(warpOp->getNumResults());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        newValues.push_back(Value());
      else
        newValues.push_back(
            newWarpOp.getResult(dedupResultPositionMap.lookup(result)));
    }
    rewriter.replaceOp(warpOp, newValues);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it
// directly; it doesn't need to go through the region.
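// For example (illustrative): a result that just returns a captured value
//   %r = vector.warp_execute_on_lane_0(%laneid)
//       args(%v : vector<4xf32>) -> (vector<4xf32>) {
//   ^bb0(%arg0 : vector<4xf32>):
//     vector.yield %arg0 : vector<4xf32>
//   }
// is forwarded: all uses of %r are replaced by %v.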
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = dyn_cast<BlockArgument>(operand.get());
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
    return success();
  }
};

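/// Sink out a vector.broadcast feeding into a warp op yield. The broadcast
/// source is assumed uniform across lanes (each lane can rebuild the same
/// broadcast), so the source is yielded out of the region and the broadcast
/// is recreated on the distributed result type outside the warp op.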
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    Value broadcastSrc = broadcastOp.getSource();
    Type broadcastSrcType = broadcastSrc.getType();

    // Check that the broadcast actually spans a set of values uniformly across
    // all threads. In other words, check that each thread can reconstruct
    // their own broadcast.
    // For that we simply check that the broadcast we want to build makes sense.
    if (vector::isBroadcastableTo(broadcastSrcType, destVecType) !=
        vector::BroadcastableToResult::Success)
      return failure();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                broadcasted);
    return success();
  }
};

/// Pattern to move a shape cast out of the warp op. A shape cast is basically
/// a no-op for warp distribution; we only need to handle the shape.
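/// For example (illustrative): for a yielded
/// `vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>` distributed as
/// vector<1xf32>, the source is yielded with the rank-expanded distributed
/// type vector<1x1xf32> and a cast from vector<1x1xf32> to vector<1xf32> is
/// rebuilt after the warp op.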
struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
    if (!operand)
      return failure();
    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();

    unsigned int operandNumber = operand->getOperandNumber();
    auto castDistributedType =
        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
    VectorType castOriginalType = oldCastOp.getSourceVectorType();
    VectorType castResultType = castDistributedType;

    // We expect the distributed type to have a smaller rank than the original
    // type. Prepend with size-one dimensions to make them the same.
    unsigned castDistributedRank = castDistributedType.getRank();
    unsigned castOriginalRank = castOriginalType.getRank();
    if (castDistributedRank < castOriginalRank) {
      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
      llvm::append_range(shape, castDistributedType.getShape());
      castDistributedType =
          VectorType::get(shape, castDistributedType.getElementType());
    }

    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newCast = rewriter.create<vector::ShapeCastOp>(
        oldCastOp.getLoc(), castResultType,
        newWarpOp->getResult(newRetIndices[0]));
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
    return success();
  }
};

/// Pattern to move out a vector.extract of a single-element vector. Such
/// extracts don't need to be distributed and can just be propagated outside
/// of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    Location loc = extractOp.getLoc();

    // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
    assert(extractSrcType.getRank() > 0 &&
           "vector.extract does not support rank 0 sources");

    // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
    // canonicalized to %v.
    if (extractOp.getNumIndices() == 0)
      return failure();

    // Rewrite a vector.extract with a 1d source to vector.extractelement.
    if (extractSrcType.getRank() == 1) {
      if (extractOp.hasDynamicPosition())
        // TODO: Dynamic position not supported yet.
        return failure();

      assert(extractOp.getNumIndices() == 1 && "expected 1 index");
      int64_t pos = extractOp.getStaticPosition()[0];
      rewriter.setInsertionPoint(extractOp);
      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
          extractOp, extractOp.getVector(),
          rewriter.create<arith::ConstantIndexOp>(loc, pos));
      return success();
    }

    // All following cases are 2d or higher dimensional source vectors.

    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
      // There is no distribution, this is a broadcast. Simply move the extract
      // out of the warp op.
      // TODO: This could be optimized. E.g., in case of a scalar result, let
      // one lane extract and shuffle the result to all other lanes (same as
      // the 1d case).
      SmallVector<size_t> newRetIndices;
      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
          rewriter, warpOp, {extractOp.getVector()},
          {extractOp.getSourceVectorType()}, newRetIndices);
      rewriter.setInsertionPointAfter(newWarpOp);
      Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
      // Extract from the distributed vector.
      Value newExtract = rewriter.create<vector::ExtractOp>(
          loc, distributedVec, extractOp.getMixedPosition());
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // Find the distributed dimension. There should be exactly one.
    auto distributedType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    auto yieldedType = cast<VectorType>(operand->get().getType());
    int64_t distributedDim = -1;
    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
      if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) {
        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
        // support distributing multiple dimensions in the future.
        assert(distributedDim == -1 && "found multiple distributed dims");
        distributedDim = i;
      }
    }
    assert(distributedDim != -1 && "could not find distributed dimension");
    (void)distributedDim;

    // Yield the source vector from the warp op.
    SmallVector<int64_t> newDistributedShape(extractSrcType.getShape().begin(),
                                             extractSrcType.getShape().end());
    for (int i = 0; i < distributedType.getRank(); ++i)
      newDistributedShape[i + extractOp.getNumIndices()] =
          distributedType.getDimSize(i);
    auto newDistributedType =
        VectorType::get(newDistributedShape, distributedType.getElementType());
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    // Extract from the distributed vector.
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, distributedVec, extractOp.getMixedPosition());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                newExtract);
    return success();
  }
};

/// Pattern to move out a vector.extractelement of 0-D vectors. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
                       PatternBenefit b = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
        warpShuffleFromIdxFn(std::move(fn)) {}
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
      return isa<vector::ExtractElementOp>(op);
    });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
    VectorType extractSrcType = extractOp.getSourceVectorType();
    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
    Type elType = extractSrcType.getElementType();
    VectorType distributedVecType;
    if (!is0dOrVec1Extract) {
      assert(extractSrcType.getRank() == 1 &&
             "expected that extractelement src rank is 0 or 1");
      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
        return failure();
      int64_t elementsPerLane =
          extractSrcType.getShape()[0] / warpOp.getWarpSize();
      distributedVecType = VectorType::get({elementsPerLane}, elType);
    } else {
      distributedVecType = extractSrcType;
    }
    // Yield the source vector from the warp op.
    Location loc = extractOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);

    // 0d extract: The new warp op broadcasts the source vector to all lanes.
    // All lanes extract the scalar.
    if (is0dOrVec1Extract) {
      Value newExtract;
      if (extractSrcType.getRank() == 1) {
        newExtract = rewriter.create<vector::ExtractElementOp>(
            loc, distributedVec,
            rewriter.create<arith::ConstantIndexOp>(loc, 0));
      } else {
        newExtract =
            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
      }
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newExtract);
      return success();
    }

    // 1d extract: Distribute the source vector. One lane extracts and shuffles
    // the value to all other lanes.
    int64_t elementsPerLane = distributedVecType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of extracting thread: pos / elementsPerLane
    Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
    // Extract at position: pos % elementsPerLane
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
                                                 extractOp.getPosition())
                  .getResult();
    Value extracted =
        rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);

    // Shuffle the extracted value to all lanes.
    Value shuffled = warpShuffleFromIdxFn(
        loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize());
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled);
    return success();
  }

private:
  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
};

struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
    VectorType vecType = insertOp.getDestVectorType();
    VectorType distrType =
        cast<VectorType>(warpOp.getResult(operandNumber).getType());
    bool hasPos = static_cast<bool>(insertOp.getPosition());

    // Yield destination vector, source scalar and position from the warp op.
    SmallVector<Value> additionalResults{insertOp.getDest(),
                                         insertOp.getSource()};
    SmallVector<Type> additionalResultTypes{distrType,
                                            insertOp.getSource().getType()};
    if (hasPos) {
      additionalResults.push_back(insertOp.getPosition());
      additionalResultTypes.push_back(insertOp.getPosition().getType());
    }
    Location loc = insertOp.getLoc();
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, additionalResults, additionalResultTypes,
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
    Value newSource = newWarpOp->getResult(newRetIndices[1]);
    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
    rewriter.setInsertionPointAfter(newWarpOp);

    if (vecType == distrType) {
      // Broadcast: Simply move the vector.insertelement op out.
      Value newInsert = rewriter.create<vector::InsertElementOp>(
          loc, newSource, distributedVec, newPos);
      rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
                                  newInsert);
      return success();
    }

    // This is a distribution. Only one lane should insert.
    int64_t elementsPerLane = distrType.getShape()[0];
    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
    // tid of inserting thread: pos / elementsPerLane
    Value insertingLane = rewriter.create<affine::AffineApplyOp>(
        loc, sym0.ceilDiv(elementsPerLane), newPos);
    // Insert position: pos % elementsPerLane
    Value pos =
        elementsPerLane == 1
            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
            : rewriter
                  .create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
                                                 newPos)
                  .getResult();
    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
    Value newResult =
        rewriter
            .create<scf::IfOp>(
                loc, isInsertingLane,
                /*thenBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  Value newInsert = builder.create<vector::InsertElementOp>(
                      loc, newSource, distributedVec, pos);
                  builder.create<scf::YieldOp>(loc, newInsert);
                },
                /*elseBuilder=*/
                [&](OpBuilder &builder, Location loc) {
                  builder.create<scf::YieldOp>(loc, distributedVec);
                })
            .getResult(0);
    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
    return success();
  }
};

1291 struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
1293 
1294  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1295  PatternRewriter &rewriter) const override {
1296  OpOperand *operand = getWarpResult(
1297  warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
1298  if (!operand)
1299  return failure();
1300  unsigned int operandNumber = operand->getOperandNumber();
1301  auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
1302  Location loc = insertOp.getLoc();
1303 
1304  // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
1305  if (insertOp.getNumIndices() == 0)
1306  return failure();
1307 
1308  // Rewrite vector.insert with 1d dest to vector.insertelement.
1309  if (insertOp.getDestVectorType().getRank() == 1) {
1310  if (insertOp.hasDynamicPosition())
1311  // TODO: Dinamic position not supported yet.
1312  return failure();
1313 
1314  assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1315  int64_t pos = insertOp.getStaticPosition()[0];
1316  rewriter.setInsertionPoint(insertOp);
1317  rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
1318  insertOp, insertOp.getSource(), insertOp.getDest(),
1319  rewriter.create<arith::ConstantIndexOp>(loc, pos));
1320  return success();
1321  }
1322 
1323  if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
1324  // There is no distribution, this is a broadcast. Simply move the insert
1325  // out of the warp op.
1326  SmallVector<size_t> newRetIndices;
1327  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1328  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1329  {insertOp.getSourceType(), insertOp.getDestVectorType()},
1330  newRetIndices);
1331  rewriter.setInsertionPointAfter(newWarpOp);
1332  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1333  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1334  Value newResult = rewriter.create<vector::InsertOp>(
1335  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1336  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
1337  newResult);
1338  return success();
1339  }
1340 
1341  // Find the distributed dimension. There should be exactly one.
1342  auto distrDestType =
1343  cast<VectorType>(warpOp.getResult(operandNumber).getType());
1344  auto yieldedType = cast<VectorType>(operand->get().getType());
1345  int64_t distrDestDim = -1;
1346  for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
1347  if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
1348  // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
1349  // support distributing multiple dimensions in the future.
1350  assert(distrDestDim == -1 && "found multiple distributed dims");
1351  distrDestDim = i;
1352  }
1353  }
1354  assert(distrDestDim != -1 && "could not find distributed dimension");
1355 
1356  // Compute the distributed source vector type.
1357  VectorType srcVecType = cast<VectorType>(insertOp.getSourceType());
1358  SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
1359  srcVecType.getShape().end());
1360  // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
1361  // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
1362  // insert a smaller vector<3xf32>.
1363  // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
1364  // case, one lane will insert the source vector<96xf32>. The other
1365  // lanes will not do anything.
1366  int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices();
1367  if (distrSrcDim >= 0)
1368  distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
1369  auto distrSrcType =
1370  VectorType::get(distrSrcShape, distrDestType.getElementType());
1371 
1372  // Yield source and dest vectors from warp op.
1373  SmallVector<size_t> newRetIndices;
1374  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1375  rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
1376  {distrSrcType, distrDestType}, newRetIndices);
1377  rewriter.setInsertionPointAfter(newWarpOp);
1378  Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
1379  Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
1380 
1381  // Insert into the distributed vector.
1382  Value newResult;
1383  if (distrSrcDim >= 0) {
1384  // Every lane inserts a small piece.
1385  newResult = rewriter.create<vector::InsertOp>(
1386  loc, distributedSrc, distributedDest, insertOp.getMixedPosition());
1387  } else {
1388  // One lane inserts the entire source vector.
1389  int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
1390  SmallVector<OpFoldResult> pos = insertOp.getMixedPosition();
1391  SmallVector<int64_t> newPos = getAsIntegers(pos);
1392  // tid of inserting lane: pos / elementsPerLane
1393  Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
1394  loc, newPos[distrDestDim] / elementsPerLane);
1395  Value isInsertingLane = rewriter.create<arith::CmpIOp>(
1396  loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
1397  // Insert position: pos % elementsPerLane
1398  newPos[distrDestDim] %= elementsPerLane;
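 // E.g. for the example above (distrDestDim = 0, dest vector<128x96xf32>
 // distributed over 32 lanes, so elementsPerLane = 4), an insert at [50] is
 // performed by lane 12 (= 50 / 4) at local position 2 (= 50 % 4).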
1399  auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
1400  Value newInsert = builder.create<vector::InsertOp>(
1401  loc, distributedSrc, distributedDest, newPos);
1402  builder.create<scf::YieldOp>(loc, newInsert);
1403  };
1404  auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
1405  builder.create<scf::YieldOp>(loc, distributedDest);
1406  };
1407  newResult = rewriter
1408  .create<scf::IfOp>(loc, isInsertingLane,
1409  /*thenBuilder=*/insertingBuilder,
1410  /*elseBuilder=*/nonInsertingBuilder)
1411  .getResult(0);
1412  }
1413 
1414  rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult);
1415  return success();
1416  }
1417 };
1418 
1419 /// Sink an scf.for region out of a WarpExecuteOnLane0Op. This can be done only if
1420 /// the scf.ForOp is the last operation in the region so that it doesn't change
1421 /// the order of execution. This creates a new scf.for region after the
1422 /// WarpExecuteOnLane0Op. The new scf.for region will contain a new
1423 /// WarpExecuteOnLane0Op region. Example:
1424 /// ```
1425 /// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1426 /// ...
1427 /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
1428 /// -> (vector<128xf32>) {
1429 /// ...
1430 /// scf.yield %r : vector<128xf32>
1431 /// }
1432 /// vector.yield %v1 : vector<128xf32>
1433 /// }
1434 /// ```
1435 /// To:
 /// ```
1436 /// %w0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
1437 /// ...
1438 /// vector.yield %v : vector<128xf32>
1439 /// }
1440 /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
1441 /// -> (vector<4xf32>) {
1442 /// %iw = vector.warp_execute_on_lane_0(%laneid)
1443 /// args(%varg : vector<4xf32>) -> (vector<4xf32>) {
1444 /// ^bb0(%arg: vector<128xf32>):
1445 /// ...
1446 /// vector.yield %ir : vector<128xf32>
1447 /// }
1448 /// scf.yield %iw : vector<4xf32>
1449 /// }
1450 /// ```
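/// Values defined above the scf.for inside the original warp op and used by
/// the loop body are yielded by the first warp op and threaded through the
/// inner warp op's block arguments.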
1451 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1452 
1453  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
1454  : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1455  distributionMapFn(std::move(fn)) {}
1456 
1457  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1458  PatternRewriter &rewriter) const override {
1459  auto yield = cast<vector::YieldOp>(
1460  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1461  // Only pick up forOp if it is the last op in the region.
1462  Operation *lastNode = yield->getPrevNode();
1463  auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
1464  if (!forOp)
1465  return failure();
1466  // Collect values defined inside the warp op and used by the forOp. These
1467  // values need to be returned by the original warpOp and passed on to the
1468  // new op.
1469  llvm::SmallSetVector<Value, 32> escapingValues;
1470  SmallVector<Type> inputTypes;
1471  SmallVector<Type> distTypes;
1472  mlir::visitUsedValuesDefinedAbove(
1473  forOp.getBodyRegion(), [&](OpOperand *operand) {
1474  Operation *parent = operand->get().getParentRegion()->getParentOp();
1475  if (warpOp->isAncestor(parent)) {
1476  if (!escapingValues.insert(operand->get()))
1477  return;
1478  Type distType = operand->get().getType();
1479  if (auto vecType = dyn_cast<VectorType>(distType)) {
1480  AffineMap map = distributionMapFn(operand->get());
1481  distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1482  }
1483  inputTypes.push_back(operand->get().getType());
1484  distTypes.push_back(distType);
1485  }
1486  });
1487 
1488  SmallVector<size_t> newRetIndices;
1489  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1490  rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
1491  newRetIndices);
1492  yield = cast<vector::YieldOp>(
1493  newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
1494 
1495  SmallVector<Value> newOperands;
1496  SmallVector<unsigned> resultIdx;
1497  // Collect all the outputs coming from the forOp.
1498  for (OpOperand &yieldOperand : yield->getOpOperands()) {
1499  if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
1500  continue;
1501  auto forResult = cast<OpResult>(yieldOperand.get());
1502  newOperands.push_back(
1503  newWarpOp.getResult(yieldOperand.getOperandNumber()));
1504  yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
1505  resultIdx.push_back(yieldOperand.getOperandNumber());
1506  }
1507 
1508  OpBuilder::InsertionGuard g(rewriter);
1509  rewriter.setInsertionPointAfter(newWarpOp);
1510 
1511  // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1512  // inside.
1513  auto newForOp = rewriter.create<scf::ForOp>(
1514  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
1515  forOp.getStep(), newOperands);
1516  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
1517 
1518  SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
1519  newForOp.getRegionIterArgs().end());
1520  SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
1521  forOp.getResultTypes().end());
1522  llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
1523  for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1524  warpInput.push_back(newWarpOp.getResult(retIdx));
1525  argIndexMapping[escapingValues[i]] = warpInputType.size();
1526  warpInputType.push_back(inputTypes[i]);
1527  }
1528  auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
1529  newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
1530  newWarpOp.getWarpSize(), warpInput, warpInputType);
1531 
1532  SmallVector<Value> argMapping;
1533  argMapping.push_back(newForOp.getInductionVar());
1534  for (Value arg : innerWarp.getBody()->getArguments()) {
1535  argMapping.push_back(arg);
1536  }
1537  argMapping.resize(forOp.getBody()->getNumArguments());
1538  SmallVector<Value> yieldOperands;
1539  for (Value operand : forOp.getBody()->getTerminator()->getOperands())
1540  yieldOperands.push_back(operand);
1541  rewriter.eraseOp(forOp.getBody()->getTerminator());
1542  rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
1543  rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
1544  rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
1545  rewriter.setInsertionPointAfter(innerWarp);
1546  if (!innerWarp.getResults().empty())
1547  rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
1548  rewriter.eraseOp(forOp);
1549  // Replace the warpOp result coming from the original ForOp.
1550  for (const auto &res : llvm::enumerate(resultIdx)) {
1551  rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
1552  newForOp.getResult(res.index()));
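 // scf.for operands are (lb, ub, step, init args...), so iter_arg i is
 // operand i + 3.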
1553  newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
1554  }
1555  newForOp.walk([&](Operation *op) {
1556  for (OpOperand &operand : op->getOpOperands()) {
1557  auto it = argIndexMapping.find(operand.get());
1558  if (it == argIndexMapping.end())
1559  continue;
1560  operand.set(innerWarp.getBodyRegion().getArgument(it->second));
1561  }
1562  });
1563  return success();
1564  }
1565 
1566 private:
1567  DistributionMapFn distributionMapFn;
1568 };
1569 
1570 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1571 /// The vector is reduced in parallel. Currently limited to vector sizes that
1572 /// are a multiple of the warp size. E.g.:
1573 /// ```
1574 /// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1575 /// %0 = "some_def"() : () -> (vector<32xf32>)
1576 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
1577 /// vector.yield %1 : f32
1578 /// }
1579 /// ```
1580 /// is lowered to:
1581 /// ```
1582 /// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
1583 /// %1 = "some_def"() : () -> (vector<32xf32>)
1584 /// vector.yield %1 : vector<32xf32>
1585 /// }
1586 /// %a = vector.extract %0[0] : f32 from vector<1xf32>
1587 /// %r = ("warp.reduction %a")
1588 /// ```
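/// If the reduction has an accumulator, the accumulator is yielded as well and
/// combined with the distributed reduction result after the warp op.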
1589 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
1590  WarpOpReduction(MLIRContext *context,
1591  DistributedReductionFn distributedReductionFn,
1592  PatternBenefit benefit = 1)
1593  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
1594  distributedReductionFn(std::move(distributedReductionFn)) {}
1595 
1596  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1597  PatternRewriter &rewriter) const override {
1598  OpOperand *yieldOperand = getWarpResult(
1599  warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
1600  if (!yieldOperand)
1601  return failure();
1602 
1603  auto reductionOp =
1604  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
1605  auto vectorType = cast<VectorType>(reductionOp.getVector().getType());
1606  // Only rank 1 vectors supported.
1607  if (vectorType.getRank() != 1)
1608  return rewriter.notifyMatchFailure(
1609  warpOp, "Only rank 1 reductions can be distributed.");
1610  // Only vector sizes that are a multiple of the warp size are supported.
1611  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
1612  return rewriter.notifyMatchFailure(
1613  warpOp, "Reduction vector dimension must be a multiple of warp size.");
1614  if (!reductionOp.getType().isIntOrFloat())
1615  return rewriter.notifyMatchFailure(
1616  warpOp, "Reduction distribution currently only supports floats and "
1617  "integer types.");
1618 
1619  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
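 // E.g. a vector<64xf32> on a warp of 32 lanes leaves each lane holding a
 // vector<2xf32>, which distributedReductionFn then reduces across lanes.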
1620  // Yield the vector that will be reduced from the WarpExecuteOnLane0Op.
1621  unsigned operandIndex = yieldOperand->getOperandNumber();
1622  SmallVector<Value> yieldValues = {reductionOp.getVector()};
1623  SmallVector<Type> retTypes = {
1624  VectorType::get({numElements}, reductionOp.getType())};
1625  if (reductionOp.getAcc()) {
1626  yieldValues.push_back(reductionOp.getAcc());
1627  retTypes.push_back(reductionOp.getAcc().getType());
1628  }
1629  SmallVector<size_t> newRetIndices;
1630  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1631  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
1632  rewriter.setInsertionPointAfter(newWarpOp);
1633 
1634  // Obtain data to reduce for a single lane.
1635  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
1636  // Distribute and reduce across threads.
1637  Value fullReduce =
1638  distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec,
1639  reductionOp.getKind(), newWarpOp.getWarpSize());
1640  if (reductionOp.getAcc()) {
1641  fullReduce = vector::makeArithReduction(
1642  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
1643  newWarpOp.getResult(newRetIndices[1]));
1644  }
1645  rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce);
1646  return success();
1647  }
1648 
1649 private:
1650  DistributedReductionFn distributedReductionFn;
1651 };
1652 
1653 } // namespace
1654 
1655 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
1656  RewritePatternSet &patterns,
1657  const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
1658  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
1659 }
1660 
1661 void mlir::vector::populateDistributeTransferWriteOpPatterns(
1662  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1663  PatternBenefit benefit) {
1664  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
1665  benefit);
1666 }
1667 
1668 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1669  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
1670  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
1671  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
1672  WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
1673  WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
1674  WarpOpInsert>(patterns.getContext(), benefit);
1675  patterns.add<WarpOpExtractElement>(patterns.getContext(),
1676  warpShuffleFromIdxFn, benefit);
1677  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
1678  benefit);
1679 }
1680 
1681 void mlir::vector::populateDistributeReduction(
1682  RewritePatternSet &patterns,
1683  const DistributedReductionFn &distributedReductionFn,
1684  PatternBenefit benefit) {
1685  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn,
1686  benefit);
1687 }
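// Typical usage (a minimal sketch, assuming a `funcOp` to rewrite and
// user-provided `distributionMapFn`, `warpShuffleFromIdxFn` and
// `distributedReductionFn` callbacks are in scope):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateDistributeTransferWriteOpPatterns(patterns,
//                                                     distributionMapFn);
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionMapFn, warpShuffleFromIdxFn);
//   vector::populateDistributeReduction(patterns, distributedReductionFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));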
1688 
1689 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
1690  Block *body = warpOp.getBody();
1691 
1692  // Keep track of the ops we want to hoist.
1693  llvm::SmallSetVector<Operation *, 8> opsToMove;
1694 
1695  // Helper to check if a value is or will be defined outside of the region.
1696  auto isDefinedOutsideOfBody = [&](Value value) {
1697  auto *definingOp = value.getDefiningOp();
1698  return (definingOp && opsToMove.count(definingOp)) ||
1699  warpOp.isDefinedOutsideOfRegion(value);
1700  };
1701 
1702  // Do not use walk here, as we do not want to go into nested regions and hoist
1703  // operations from there.
1704  for (auto &op : body->without_terminator()) {
1705  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
1706  return isa<VectorType>(result.getType());
1707  });
1708  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
1709  opsToMove.insert(&op);
1710  }
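 // E.g. side-effect-free scalar arithmetic whose operands are defined outside
 // the region (or by other hoisted ops) is collected; anything producing a
 // vector result stays inside the warp region.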
1711 
1712  // Move all the ops marked as uniform outside of the region.
1713  for (Operation *op : opsToMove)
1714  op->moveBefore(warpOp);
1715 }