MLIR  16.0.0git
VectorDistribute.cpp
Go to the documentation of this file.
1 //===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "llvm/ADT/SetVector.h"
18 #include <utility>
19 
20 using namespace mlir;
21 using namespace mlir::vector;
22 
/// Lower a WarpExecuteOnLane0Op into an scf.if guarded on lane 0, moving
/// values across lanes through memory:
///  - region arguments coming from outside are stored per-lane into scratch
///    buffers (allocated via options.warpAllocationFn) and re-loaded by lane 0
///    inside the scf.if;
///  - yielded values are stored by lane 0 and re-loaded by every lane after
///    the scf.if, either distributed (vector result) or broadcast (same type).
/// Lanes are synchronized via options.warpSyncronizationFn between the stores
/// and the loads.
// NOTE(review): the signature appears truncated in this copy — the trailing
// `options` parameter (used throughout the body) and the opening brace are
// missing. Verify against the original file.
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op that only lane 0 enters.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    // Allocate the buffer before the scf.if so every lane sees it.
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer. The store happens before the scf.if, so
    // each lane writes its own slice at offset laneid * storeSize.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer: lane 0 (inside the scf.if) reads the
    // whole buffer back as the full-sized region argument type.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp, substituting the loads for the old block
  // arguments.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    // Allocate the result buffer before the scf.if.
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer (inside the scf.if, i.e. lane 0 only).
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp), executed by all lanes.
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      // Distributed result: each lane loads its own slice.
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}
139 
140 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
141 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
142  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
143  ValueRange newYieldedValues, TypeRange newReturnTypes) {
144  // Create a new op before the existing one, with the extra operands.
145  OpBuilder::InsertionGuard g(rewriter);
146  rewriter.setInsertionPoint(warpOp);
147  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
148  warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
149  warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
150 
151  Region &opBody = warpOp.getBodyRegion();
152  Region &newOpBody = newWarpOp.getBodyRegion();
153  Block &newOpFirstBlock = newOpBody.front();
154  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
155  rewriter.eraseBlock(&newOpFirstBlock);
156  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
157  "expected WarpOp with single block");
158 
159  auto yield =
160  cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
161 
162  rewriter.updateRootInPlace(
163  yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
164  return newWarpOp;
165 }
166 
167 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
168 /// `indices` return the index of each new output.
169 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
170  RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
171  ValueRange newYieldedValues, TypeRange newReturnTypes,
172  llvm::SmallVector<size_t> &indices) {
173  SmallVector<Type> types(warpOp.getResultTypes().begin(),
174  warpOp.getResultTypes().end());
175  auto yield = cast<vector::YieldOp>(
176  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
177  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
178  yield.getOperands().end());
179  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
180  if (yieldValues.insert(std::get<0>(newRet))) {
181  types.push_back(std::get<1>(newRet));
182  indices.push_back(yieldValues.size() - 1);
183  } else {
184  // If the value already exit the region don't create a new output.
185  for (auto &yieldOperand : llvm::enumerate(yieldValues.getArrayRef())) {
186  if (yieldOperand.value() == std::get<0>(newRet)) {
187  indices.push_back(yieldOperand.index());
188  break;
189  }
190  }
191  }
192  }
193  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
194  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
195  rewriter, warpOp, yieldValues.getArrayRef(), types);
196  rewriter.replaceOp(warpOp,
197  newWarpOp.getResults().take_front(warpOp.getNumResults()));
198  return newWarpOp;
199 }
200 
201 /// Helper to know if an op can be hoisted out of the region.
202 static bool canBeHoisted(Operation *op,
203  function_ref<bool(Value)> definedOutside) {
204  return llvm::all_of(op->getOperands(), definedOutside) &&
205  isSideEffectFree(op) && op->getNumRegions() == 0;
206 }
207 
208 /// Return a value yielded by `warpOp` which statifies the filter lamdba
209 /// condition and is not dead.
210 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
211  std::function<bool(Operation *)> fn) {
212  auto yield = cast<vector::YieldOp>(
213  warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
214  for (OpOperand &yieldOperand : yield->getOpOperands()) {
215  Value yieldValues = yieldOperand.get();
216  Operation *definedOp = yieldValues.getDefiningOp();
217  if (definedOp && fn(definedOp)) {
218  if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
219  return &yieldOperand;
220  }
221  }
222  return {};
223 }
224 
// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
// NOTE(review): the first line of this helper's signature (return type, name,
// and the `rewriter` parameter used in the body) is missing in this copy;
// verify against the original file.
                                         Location loc, Operation *op,
                                         ArrayRef<Value> operands,
                                         ArrayRef<Type> resultTypes) {
  // Recreate the op with the same name and attributes, but with the new
  // operands and result types.
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}
235 
/// Currently the distribution map is implicit based on the vector shape. In the
/// future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
// NOTE(review): the signature line of this helper (presumably
// `calculateImplicitMap(Value yield, Value ret)`, judging by the call site in
// WarpOpTransferRead below) and the declaration of `perm` are missing in this
// copy; verify against the original file.
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  // Check which dimensions of the yield value are different than the dimensions
  // of the result to know the distributed dimensions. Then associate each
  // distributed dimension to an ID in order.
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  // The map keeps only the dimensions that differ (i.e. are distributed).
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}
261 
262 namespace {
263 
/// Pattern that applies rewriteWarpOpToScfFor (above) to every
/// WarpExecuteOnLane0Op, using the options captured at construction time.
struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): the constructor's `options` parameter line is missing in
  // this copy — the `options(options)` member initializer below requires it.
  // Verify against the original file.
  WarpOpToScfForPattern(MLIRContext *context,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  // NOTE(review): the declaration of the `options` member is missing in this
  // copy; verify against the original file.
};
279 
280 /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
281 /// op with the proper return type.
282 /// The new write op is updated to write the result of the new warp execute op.
283 /// The old `writeOp` is deleted.
284 static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
285  WarpExecuteOnLane0Op warpOp,
286  vector::TransferWriteOp writeOp,
287  VectorType targetType) {
288  assert(writeOp->getParentOp() == warpOp &&
289  "write must be nested immediately under warp");
290  OpBuilder::InsertionGuard g(rewriter);
291  SmallVector<size_t> newRetIndices;
292  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
293  rewriter, warpOp, ValueRange{{writeOp.getVector()}},
294  TypeRange{targetType}, newRetIndices);
295  rewriter.setInsertionPointAfter(newWarpOp);
296  auto newWriteOp =
297  cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
298  rewriter.eraseOp(writeOp);
299  newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
300  return newWriteOp;
301 }
302 
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32>
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  // NOTE(review): the constructor's base-class initializer line (starting
  // with ':') is missing in this copy; verify against the original file.
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that
  /// are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op
    // to separate it from the rest.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distribution map.
    AffineMap map = distributionMapFn(writeOp);
    if (map.getNumResults() != 1)
      return writeOp->emitError("multi-dim distribution not implemented yet");

    // 3. Compute the targetType using the distribution map: each distributed
    // dimension is divided by the warp size.
    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
                                     writtenVectorType.getShape().end());
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      // Bail out on dims that don't divide evenly by the warp size.
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writtenVectorType.getElementType());

    // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 5. Reindex the write using the distribution map: each distributed memref
    // index becomes `index + laneid * distributedDimSize`.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vector of 1 element for now to not serialize large vector
    // store. This can later be controlled by user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    // Yield the written vector out of the warp op.
    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(newWarpOp.getResult(newRetIndices[0]));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    // All operands other than the written vector must be defined outside the
    // warp region.
    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    // Prefer full distribution; fall back to extracting single-element writes.
    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};
458 
/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // NOTE(review): the body of this filter lambda is missing in this copy
    // (presumably the elementwise-op check); verify against the original file.
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    // For each operand of the elementwise op, compute the distributed type it
    // should have outside the region and yield the operand out.
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Rebuild the elementwise op outside using the distributed operands.
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(newRetIndices[i]);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    // NOTE(review): the statement that creates `newOp` (used below) is
    // truncated in this copy — its first line is missing; verify against the
    // original file.
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};
526 
/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0 {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
    if (!yieldOperand)
      return failure();
    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
    // Only splat constants can be re-materialized at the distributed type.
    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
    if (!dense)
      return failure();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Attribute scalarAttr = dense.getSplatValue<Attribute>();
    // NOTE(review): the statement building `newAttr` (used below) is truncated
    // in this copy — its first line is missing; verify against the original
    // file.
        warpOp.getResult(operandIndex).getType(), scalarAttr);
    Location loc = warpOp.getLoc();
    // Re-create the constant after the warp op with the distributed type.
    rewriter.setInsertionPointAfter(warpOp);
    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
    warpOp.getResult(operandIndex).replaceAllUsesWith(distConstant);
    return success();
  }
};
564 
/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32> vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Look for a yielded transfer_read whose result is actually used.
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    // Figure out which memref indices correspond to distributed vector
    // dimensions so they can be offset per lane.
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      // index = index + laneid * scale, where scale is the distributed size
      // of that dimension.
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    // Re-create the read outside the warp op with the distributed type and
    // the per-lane indices.
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};
623 
/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Collect only the live results and their matching yield operands.
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    // All results were live: nothing to do.
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    // Remap live results onto the compacted result list of the new op.
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};
654 
// If an operand is directly yielded out of the region we can forward it
// directly and it doesn't need to go through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        // Types must match exactly to forward without a cast.
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      // Otherwise, the yielded value may be a region argument of the warp op
      // itself: forward the warp op operand bound to that argument.
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};
696 
/// Sink a vector.broadcast out of the warp op: the broadcast source is
/// yielded out instead, and the broadcast is re-created after the warp op
/// with the distributed result type.
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    // Yield the broadcast source instead of the broadcast result.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()}, newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Re-create the broadcast outside, producing the distributed type.
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResult(newRetIndices[0]));
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
    return success();
  }
};
721 
/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
    // Only handle extracts from single-element source vectors.
    if (extractOp.getVectorType().getNumElements() != 1)
      return failure();
    Location loc = extractOp.getLoc();
    // Yield the source vector out of the warp op and redo the extract after
    // it, at the same position.
    SmallVector<size_t> newRetIndices;
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
        newRetIndices);
    rewriter.setInsertionPointAfter(newWarpOp);
    Value newExtract = rewriter.create<vector::ExtractOp>(
        loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
    return success();
  }
};
748 
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
///   -> (vector<4xf32>) {
///     %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///     }
///     scf.yield %iw : vector<4xf32>
///  }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  // NOTE(review): a line is missing here in this copy (likely the inherited
  // constructor using-declaration); verify against the original file.
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp. Each such yield operand
    // is temporarily re-pointed at the matching iter operand so the for op
    // can be detached from the region below.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    // The inner warp op receives the distributed iter args and re-expands
    // them to the original for-op result types as block arguments.
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    // Block-argument mapping for merging the old for body into the inner
    // warp op: induction variable first, then the warp block arguments.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Operand index + 3 skips the lower bound, upper bound and step
      // operands of scf.for to reach the iter args.
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};
841 
842 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
843 /// The vector is reduced in parallel. Currently limited to vector size matching
844 /// the warpOp size. E.g.:
845 /// ```
846 /// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
847 /// %0 = "some_def"() : () -> (vector<32xf32>)
848 /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32
849 /// vector_ext.yield %1 : f32
850 /// }
851 /// ```
852 /// is lowered to:
853 /// ```
854 /// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
855 /// %1 = "some_def"() : () -> (vector<32xf32>)
856 /// vector_ext.yield %1 : vector<32xf32>
857 /// }
858 /// %a = vector.extract %0[0] : vector<1xf32>
859 /// %r = ("warp.reduction %a")
860 /// ```
861 struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
862  WarpOpReduction(MLIRContext *context,
863  DistributedReductionFn distributedReductionFn,
864  PatternBenefit benefit = 1)
865  : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
866  distributedReductionFn(distributedReductionFn) {}
867 
868  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
869  PatternRewriter &rewriter) const override {
870  OpOperand *yieldOperand = getWarpResult(
871  warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
872  if (!yieldOperand)
873  return failure();
874 
875  auto reductionOp =
876  cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
877  auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
878  // Only rank 1 vectors supported.
879  if (vectorType.getRank() != 1)
880  return rewriter.notifyMatchFailure(
881  warpOp, "Only rank 1 reductions can be distributed.");
882  // Only warp_size-sized vectors supported.
883  if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0)
884  return rewriter.notifyMatchFailure(
885  warpOp, "Reduction vector dimension must match was size.");
886  // Only f32 and i32 element types are supported.
887  if (!reductionOp.getType().isF32() &&
888  !reductionOp.getType().isSignlessInteger(32))
889  return rewriter.notifyMatchFailure(
890  warpOp,
891  "Reduction distribution currently only supports 32bits types.");
892 
893  int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize();
894  // Return vector that will be reduced from the WarpExecuteOnLane0Op.
895  unsigned operandIndex = yieldOperand->getOperandNumber();
896  SmallVector<Value> yieldValues = {reductionOp.getVector()};
897  SmallVector<Type> retTypes = {
898  VectorType::get({numElements}, reductionOp.getType())};
899  if (reductionOp.getAcc()) {
900  yieldValues.push_back(reductionOp.getAcc());
901  retTypes.push_back(reductionOp.getAcc().getType());
902  }
903  SmallVector<size_t> newRetIndices;
904  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
905  rewriter, warpOp, yieldValues, retTypes, newRetIndices);
906  rewriter.setInsertionPointAfter(newWarpOp);
907 
908  Value laneValVec = newWarpOp.getResult(newRetIndices[0]);
909  // First reduce on a single thread.
910  Value perLaneReduction = rewriter.create<vector::ReductionOp>(
911  reductionOp.getLoc(), reductionOp.getKind(), laneValVec);
912  // Then distribute across threads.
913  Value fullReduce =
914  distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
915  reductionOp.getKind(), newWarpOp.getWarpSize());
916  if (reductionOp.getAcc()) {
917  fullReduce = vector::makeArithReduction(
918  rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
919  newWarpOp.getResult(newRetIndices[1]));
920  }
921  newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
922  return success();
923  }
924 
925 private:
926  DistributedReductionFn distributedReductionFn;
927 };
928 
929 } // namespace
930 
932  RewritePatternSet &patterns,
934  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
935 }
936 
937 void mlir::vector::populateDistributeTransferWriteOpPatterns(
938  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
939  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
940 }
941 
942 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
943  RewritePatternSet &patterns) {
944  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
945  WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
946  WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
947 }
948 
949 void mlir::vector::populateDistributeReduction(
950  RewritePatternSet &patterns,
951  DistributedReductionFn distributedReductionFn) {
952  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
953 }
954 
955 void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
956  Block *body = warpOp.getBody();
957 
958  // Keep track of the ops we want to hoist.
959  llvm::SmallSetVector<Operation *, 8> opsToMove;
960 
961  // Helper to check if a value is or will be defined outside of the region.
962  auto isDefinedOutsideOfBody = [&](Value value) {
963  auto *definingOp = value.getDefiningOp();
964  return (definingOp && opsToMove.count(definingOp)) ||
965  warpOp.isDefinedOutsideOfRegion(value);
966  };
967 
968  // Do not use walk here, as we do not want to go into nested regions and hoist
969  // operations from there.
970  for (auto &op : body->without_terminator()) {
971  bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
972  return result.getType().isa<VectorType>();
973  });
974  if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
975  opsToMove.insert(&op);
976  }
977 
978  // Move all the ops marked as uniform outside of the region.
979  for (Operation *op : opsToMove)
980  op->moveBefore(warpOp);
981 }
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options)
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
Definition: Operation.cpp:430
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
U cast() const
Definition: Attributes.h:135
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:439
static bool canBeHoisted(Operation *op, function_ref< bool(Value)> definedOutside)
Helper to know if an op can be hoisted out of the region.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static AffineMap calculateImplicitMap(Value yield, Value ret)
Currently the distribution map is implicit based on the vector shape.
BlockListType & getBlocks()
Definition: Region.h:45
This is a value defined by a result of an operation.
Definition: Value.h:425
operand_range getOperands()
Returns an iterator on the underlying Value&#39;s.
Definition: Operation.h:295
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:344
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:492
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1129
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
unsigned getNumOperands()
Definition: Operation.h:263
std::function< AffineMap(vector::TransferWriteOp)> DistributionMapFn
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:798
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static unsigned perm(const SparseTensorEncodingAttr &enc, unsigned d)
Helper method to apply dimension ordering permutation.
Operation & front()
Definition: Block.h:144
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:200
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:83
void replaceAllUsesWith(Value newValue) const
Replace all uses of &#39;this&#39; value with the new value, updating anything in the IR that uses &#39;this&#39; to ...
Definition: Value.h:162
static constexpr const bool value
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes, llvm::SmallVector< size_t > &indices)
Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:358
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:300
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isSideEffectFree(Operation *op)
Returns true if the given operation is side-effect free.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
iterator begin()
Definition: Region.h:55
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
U dyn_cast() const
Definition: Types.h:270
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
Base type for affine expression.
Definition: AffineExpr.h:68
OpResult getResult(unsigned idx)
Get the &#39;idx&#39;th result of this operation.
Definition: Operation.h:324
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
unsigned getNumResults() const
Definition: AffineMap.cpp:302
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This represents an operation in an abstracted form, suitable for use with the builder APIs...
A multi-dimensional affine map Affine map&#39;s are immutable like Type&#39;s, and they are uniqued...
Definition: AffineMap.h:42
This class represents an argument of a Block.
Definition: Value.h:300
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:307
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:489
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:315
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:203
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value v2)
Return the result value of reducing two scalar/vector values with the corresponding arith operation...
static OpOperand * getWarpResult(WarpExecuteOnLane0Op warpOp, std::function< bool(Operation *)> fn)
Return a value yielded by warpOp which satisfies the filter lambda condition and is not dead...
static llvm::ManagedStatic< PassManagerOptions > options
static int resultIndex(int i)
Definition: Operator.cpp:308
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:377
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:317
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:294
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues, TypeRange newReturnTypes)
Helper to create a new WarpExecuteOnLane0Op with different signature.
Type getType() const
Return the type of this value.
Definition: Value.h:118
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class represents an operand of an operation.
Definition: Value.h:251
type_range getType() const
Definition: ValueRange.cpp:39
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:382
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:512
static Operation * cloneOpWithOperandsAndTypes(RewriterBase &rewriter, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:328
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
bool isa() const
Definition: Types.h:254
result_range getResults()
Definition: Operation.h:332
virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues=llvm::None)
Merge the operations of block &#39;source&#39; into the end of block &#39;dest&#39;.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
static LogicalResult rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, const WarpExecuteOnLane0LoweringOptions &options)
MLIRContext * getContext() const
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:121
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
U cast() const
Definition: Types.h:278
virtual void eraseBlock(Block *block)
This method erases all operations in a block.