MLIR  14.0.0git
AffineToStandard.cpp
Go to the documentation of this file.
1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file lowers affine constructs (If and For statements, AffineApply
10 // operations) within a function into their standard If and For equivalent ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 #include "../PassDetail.h"
20 #include "mlir/Dialect/SCF/SCF.h"
25 #include "mlir/IR/Builders.h"
26 #include "mlir/IR/IntegerSet.h"
27 #include "mlir/IR/MLIRContext.h"
28 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/Passes.h"
31 
32 using namespace mlir;
33 using namespace mlir::vector;
34 
35 namespace {
36 /// Visit affine expressions recursively and build the sequence of operations
37 /// that correspond to it. Visitation functions return an Value of the
38 /// expression subtree they visited or `nullptr` on error.
39 class AffineApplyExpander
40  : public AffineExprVisitor<AffineApplyExpander, Value> {
41 public:
42  /// This internal class expects arguments to be non-null, checks must be
43  /// performed at the call site.
44  AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
45  ValueRange symbolValues, Location loc)
46  : builder(builder), dimValues(dimValues), symbolValues(symbolValues),
47  loc(loc) {}
48 
49  template <typename OpTy>
50  Value buildBinaryExpr(AffineBinaryOpExpr expr) {
51  auto lhs = visit(expr.getLHS());
52  auto rhs = visit(expr.getRHS());
53  if (!lhs || !rhs)
54  return nullptr;
55  auto op = builder.create<OpTy>(loc, lhs, rhs);
56  return op.getResult();
57  }
58 
59  Value visitAddExpr(AffineBinaryOpExpr expr) {
60  return buildBinaryExpr<arith::AddIOp>(expr);
61  }
62 
63  Value visitMulExpr(AffineBinaryOpExpr expr) {
64  return buildBinaryExpr<arith::MulIOp>(expr);
65  }
66 
67  /// Euclidean modulo operation: negative RHS is not allowed.
68  /// Remainder of the euclidean integer division is always non-negative.
69  ///
70  /// Implemented as
71  ///
72  /// a mod b =
73  /// let remainder = srem a, b;
74  /// negative = a < 0 in
75  /// select negative, remainder + b, remainder.
76  Value visitModExpr(AffineBinaryOpExpr expr) {
77  auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
78  if (!rhsConst) {
79  emitError(
80  loc,
81  "semi-affine expressions (modulo by non-const) are not supported");
82  return nullptr;
83  }
84  if (rhsConst.getValue() <= 0) {
85  emitError(loc, "modulo by non-positive value is not supported");
86  return nullptr;
87  }
88 
89  auto lhs = visit(expr.getLHS());
90  auto rhs = visit(expr.getRHS());
91  assert(lhs && rhs && "unexpected affine expr lowering failure");
92 
93  Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs);
94  Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
95  Value isRemainderNegative = builder.create<arith::CmpIOp>(
96  loc, arith::CmpIPredicate::slt, remainder, zeroCst);
97  Value correctedRemainder =
98  builder.create<arith::AddIOp>(loc, remainder, rhs);
99  Value result = builder.create<SelectOp>(loc, isRemainderNegative,
100  correctedRemainder, remainder);
101  return result;
102  }
103 
104  /// Floor division operation (rounds towards negative infinity).
105  ///
106  /// For positive divisors, it can be implemented without branching and with a
107  /// single division operation as
108  ///
109  /// a floordiv b =
110  /// let negative = a < 0 in
111  /// let absolute = negative ? -a - 1 : a in
112  /// let quotient = absolute / b in
113  /// negative ? -quotient - 1 : quotient
114  Value visitFloorDivExpr(AffineBinaryOpExpr expr) {
115  auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
116  if (!rhsConst) {
117  emitError(
118  loc,
119  "semi-affine expressions (division by non-const) are not supported");
120  return nullptr;
121  }
122  if (rhsConst.getValue() <= 0) {
123  emitError(loc, "division by non-positive value is not supported");
124  return nullptr;
125  }
126 
127  auto lhs = visit(expr.getLHS());
128  auto rhs = visit(expr.getRHS());
129  assert(lhs && rhs && "unexpected affine expr lowering failure");
130 
131  Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
132  Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1);
133  Value negative = builder.create<arith::CmpIOp>(
134  loc, arith::CmpIPredicate::slt, lhs, zeroCst);
135  Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
136  Value dividend =
137  builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
138  Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
139  Value correctedQuotient =
140  builder.create<arith::SubIOp>(loc, noneCst, quotient);
141  Value result =
142  builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
143  return result;
144  }
145 
146  /// Ceiling division operation (rounds towards positive infinity).
147  ///
148  /// For positive divisors, it can be implemented without branching and with a
149  /// single division operation as
150  ///
151  /// a ceildiv b =
152  /// let negative = a <= 0 in
153  /// let absolute = negative ? -a : a - 1 in
154  /// let quotient = absolute / b in
155  /// negative ? -quotient : quotient + 1
156  Value visitCeilDivExpr(AffineBinaryOpExpr expr) {
157  auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
158  if (!rhsConst) {
159  emitError(loc) << "semi-affine expressions (division by non-const) are "
160  "not supported";
161  return nullptr;
162  }
163  if (rhsConst.getValue() <= 0) {
164  emitError(loc, "division by non-positive value is not supported");
165  return nullptr;
166  }
167  auto lhs = visit(expr.getLHS());
168  auto rhs = visit(expr.getRHS());
169  assert(lhs && rhs && "unexpected affine expr lowering failure");
170 
171  Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0);
172  Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1);
173  Value nonPositive = builder.create<arith::CmpIOp>(
174  loc, arith::CmpIPredicate::sle, lhs, zeroCst);
175  Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
176  Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
177  Value dividend =
178  builder.create<SelectOp>(loc, nonPositive, negated, decremented);
179  Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
180  Value negatedQuotient =
181  builder.create<arith::SubIOp>(loc, zeroCst, quotient);
182  Value incrementedQuotient =
183  builder.create<arith::AddIOp>(loc, quotient, oneCst);
184  Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
185  incrementedQuotient);
186  return result;
187  }
188 
189  Value visitConstantExpr(AffineConstantExpr expr) {
190  auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
191  return op.getResult();
192  }
193 
194  Value visitDimExpr(AffineDimExpr expr) {
195  assert(expr.getPosition() < dimValues.size() &&
196  "affine dim position out of range");
197  return dimValues[expr.getPosition()];
198  }
199 
200  Value visitSymbolExpr(AffineSymbolExpr expr) {
201  assert(expr.getPosition() < symbolValues.size() &&
202  "symbol dim position out of range");
203  return symbolValues[expr.getPosition()];
204  }
205 
206 private:
207  OpBuilder &builder;
208  ValueRange dimValues;
209  ValueRange symbolValues;
210 
211  Location loc;
212 };
213 } // namespace
214 
215 /// Create a sequence of operations that implement the `expr` applied to the
216 /// given dimension and symbol values.
218  AffineExpr expr, ValueRange dimValues,
219  ValueRange symbolValues) {
220  return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
221 }
222 
223 /// Create a sequence of operations that implement the `affineMap` applied to
224 /// the given `operands` (as if it were an AffineApplyOp).
226  Location loc,
227  AffineMap affineMap,
228  ValueRange operands) {
229  auto numDims = affineMap.getNumDims();
230  auto expanded = llvm::to_vector<8>(
231  llvm::map_range(affineMap.getResults(),
232  [numDims, &builder, loc, operands](AffineExpr expr) {
233  return expandAffineExpr(builder, loc, expr,
234  operands.take_front(numDims),
235  operands.drop_front(numDims));
236  }));
237  if (llvm::all_of(expanded, [](Value v) { return v; }))
238  return expanded;
239  return None;
240 }
241 
242 /// Given a range of values, emit the code that reduces them with "min" or "max"
243 /// depending on the provided comparison predicate. The predicate defines which
244 /// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
245 /// `cmpi` operation followed by the `select` operation:
246 ///
247 /// %cond = arith.cmpi "predicate" %v0, %v1
248 /// %result = select %cond, %v0, %v1
249 ///
250 /// Multiple values are scanned in a linear sequence. This creates data
251 /// dependences that wouldn't exist in a tree reduction, but is easier to
252 /// recognize as a reduction by the subsequent passes.
254  arith::CmpIPredicate predicate,
255  ValueRange values, OpBuilder &builder) {
256  assert(!llvm::empty(values) && "empty min/max chain");
257 
258  auto valueIt = values.begin();
259  Value value = *valueIt++;
260  for (; valueIt != values.end(); ++valueIt) {
261  auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
262  value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
263  }
264 
265  return value;
266 }
267 
268 /// Emit instructions that correspond to computing the maximum value among the
269 /// values of a (potentially) multi-output affine map applied to `operands`.
271  ValueRange operands) {
272  if (auto values = expandAffineMap(builder, loc, map, operands))
273  return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values,
274  builder);
275  return nullptr;
276 }
277 
278 /// Emit instructions that correspond to computing the minimum value among the
279 /// values of a (potentially) multi-output affine map applied to `operands`.
281  ValueRange operands) {
282  if (auto values = expandAffineMap(builder, loc, map, operands))
283  return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values,
284  builder);
285  return nullptr;
286 }
287 
288 /// Emit instructions that correspond to the affine map in the upper bound
289 /// applied to the respective operands, and compute the minimum value across
290 /// the results.
291 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
292  return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
293  op.getUpperBoundOperands());
294 }
295 
296 /// Emit instructions that correspond to the affine map in the lower bound
297 /// applied to the respective operands, and compute the maximum value across
298 /// the results.
299 Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
300  return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
301  op.getLowerBoundOperands());
302 }
303 
304 namespace {
305 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
306 public:
308 
309  LogicalResult matchAndRewrite(AffineMinOp op,
310  PatternRewriter &rewriter) const override {
311  Value reduced =
312  lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
313  if (!reduced)
314  return failure();
315 
316  rewriter.replaceOp(op, reduced);
317  return success();
318  }
319 };
320 
321 class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
322 public:
324 
325  LogicalResult matchAndRewrite(AffineMaxOp op,
326  PatternRewriter &rewriter) const override {
327  Value reduced =
328  lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
329  if (!reduced)
330  return failure();
331 
332  rewriter.replaceOp(op, reduced);
333  return success();
334  }
335 };
336 
337 /// Affine yield ops are replaced with scf.yield ops.
338 class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
339 public:
341 
342  LogicalResult matchAndRewrite(AffineYieldOp op,
343  PatternRewriter &rewriter) const override {
344  if (isa<scf::ParallelOp>(op->getParentOp())) {
345  // scf.parallel does not yield any values via its terminator scf.yield but
346  // models reductions differently using additional ops in its region.
347  rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
348  return success();
349  }
350  rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.operands());
351  return success();
352  }
353 };
354 
355 class AffineForLowering : public OpRewritePattern<AffineForOp> {
356 public:
358 
359  LogicalResult matchAndRewrite(AffineForOp op,
360  PatternRewriter &rewriter) const override {
361  Location loc = op.getLoc();
362  Value lowerBound = lowerAffineLowerBound(op, rewriter);
363  Value upperBound = lowerAffineUpperBound(op, rewriter);
364  Value step = rewriter.create<arith::ConstantIndexOp>(loc, op.getStep());
365  auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
366  step, op.getIterOperands());
367  rewriter.eraseBlock(scfForOp.getBody());
368  rewriter.inlineRegionBefore(op.region(), scfForOp.getRegion(),
369  scfForOp.getRegion().end());
370  rewriter.replaceOp(op, scfForOp.getResults());
371  return success();
372  }
373 };
374 
375 /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
376 /// operation.
377 class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
378 public:
380 
381  LogicalResult matchAndRewrite(AffineParallelOp op,
382  PatternRewriter &rewriter) const override {
383  Location loc = op.getLoc();
384  SmallVector<Value, 8> steps;
385  SmallVector<Value, 8> upperBoundTuple;
386  SmallVector<Value, 8> lowerBoundTuple;
387  SmallVector<Value, 8> identityVals;
388  // Emit IR computing the lower and upper bound by expanding the map
389  // expression.
390  lowerBoundTuple.reserve(op.getNumDims());
391  upperBoundTuple.reserve(op.getNumDims());
392  for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
393  Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
394  op.getLowerBoundsOperands());
395  if (!lower)
396  return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
397  lowerBoundTuple.push_back(lower);
398 
399  Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
400  op.getUpperBoundsOperands());
401  if (!upper)
402  return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
403  upperBoundTuple.push_back(upper);
404  }
405  steps.reserve(op.steps().size());
406  for (Attribute step : op.steps())
407  steps.push_back(rewriter.create<arith::ConstantIndexOp>(
408  loc, step.cast<IntegerAttr>().getInt()));
409 
410  // Get the terminator op.
411  Operation *affineParOpTerminator = op.getBody()->getTerminator();
412  scf::ParallelOp parOp;
413  if (op.results().empty()) {
414  // Case with no reduction operations/return values.
415  parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
416  upperBoundTuple, steps,
417  /*bodyBuilderFn=*/nullptr);
418  rewriter.eraseBlock(parOp.getBody());
419  rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
420  parOp.getRegion().end());
421  rewriter.replaceOp(op, parOp.getResults());
422  return success();
423  }
424  // Case with affine.parallel with reduction operations/return values.
425  // scf.parallel handles the reduction operation differently unlike
426  // affine.parallel.
427  ArrayRef<Attribute> reductions = op.reductions().getValue();
428  for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
429  // For each of the reduction operations get the identity values for
430  // initialization of the result values.
431  Attribute reduction = std::get<0>(pair);
432  Type resultType = std::get<1>(pair);
433  Optional<arith::AtomicRMWKind> reductionOp =
434  arith::symbolizeAtomicRMWKind(
435  static_cast<uint64_t>(reduction.cast<IntegerAttr>().getInt()));
436  assert(reductionOp.hasValue() &&
437  "Reduction operation cannot be of None Type");
438  arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
439  identityVals.push_back(
440  arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
441  }
442  parOp = rewriter.create<scf::ParallelOp>(
443  loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
444  /*bodyBuilderFn=*/nullptr);
445 
446  // Copy the body of the affine.parallel op.
447  rewriter.eraseBlock(parOp.getBody());
448  rewriter.inlineRegionBefore(op.region(), parOp.getRegion(),
449  parOp.getRegion().end());
450  assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
451  "Unequal number of reductions and operands.");
452  for (unsigned i = 0, end = reductions.size(); i < end; i++) {
453  // For each of the reduction operations get the respective mlir::Value.
454  Optional<arith::AtomicRMWKind> reductionOp =
455  arith::symbolizeAtomicRMWKind(
456  reductions[i].cast<IntegerAttr>().getInt());
457  assert(reductionOp.hasValue() &&
458  "Reduction Operation cannot be of None Type");
459  arith::AtomicRMWKind reductionOpValue = reductionOp.getValue();
460  rewriter.setInsertionPoint(&parOp.getBody()->back());
461  auto reduceOp = rewriter.create<scf::ReduceOp>(
462  loc, affineParOpTerminator->getOperand(i));
463  rewriter.setInsertionPointToEnd(&reduceOp.getReductionOperator().front());
464  Value reductionResult = arith::getReductionOp(
465  reductionOpValue, rewriter, loc,
466  reduceOp.getReductionOperator().front().getArgument(0),
467  reduceOp.getReductionOperator().front().getArgument(1));
468  rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
469  }
470  rewriter.replaceOp(op, parOp.getResults());
471  return success();
472  }
473 };
474 
475 class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
476 public:
478 
479  LogicalResult matchAndRewrite(AffineIfOp op,
480  PatternRewriter &rewriter) const override {
481  auto loc = op.getLoc();
482 
483  // Now we just have to handle the condition logic.
484  auto integerSet = op.getIntegerSet();
485  Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
486  SmallVector<Value, 8> operands(op.getOperands());
487  auto operandsRef = llvm::makeArrayRef(operands);
488 
489  // Calculate cond as a conjunction without short-circuiting.
490  Value cond = nullptr;
491  for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
492  AffineExpr constraintExpr = integerSet.getConstraint(i);
493  bool isEquality = integerSet.isEq(i);
494 
495  // Build and apply an affine expression
496  auto numDims = integerSet.getNumDims();
497  Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
498  operandsRef.take_front(numDims),
499  operandsRef.drop_front(numDims));
500  if (!affResult)
501  return failure();
502  auto pred =
503  isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
504  Value cmpVal =
505  rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
506  cond = cond
507  ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
508  : cmpVal;
509  }
510  cond = cond ? cond
511  : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
512  /*width=*/1);
513 
514  bool hasElseRegion = !op.elseRegion().empty();
515  auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
516  hasElseRegion);
517  rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.getThenRegion().back());
518  rewriter.eraseBlock(&ifOp.getThenRegion().back());
519  if (hasElseRegion) {
520  rewriter.inlineRegionBefore(op.elseRegion(),
521  &ifOp.getElseRegion().back());
522  rewriter.eraseBlock(&ifOp.getElseRegion().back());
523  }
524 
525  // Replace the Affine IfOp finally.
526  rewriter.replaceOp(op, ifOp.getResults());
527  return success();
528  }
529 };
530 
531 /// Convert an "affine.apply" operation into a sequence of arithmetic
532 /// operations using the StandardOps dialect.
533 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
534 public:
536 
537  LogicalResult matchAndRewrite(AffineApplyOp op,
538  PatternRewriter &rewriter) const override {
539  auto maybeExpandedMap =
540  expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
541  llvm::to_vector<8>(op.getOperands()));
542  if (!maybeExpandedMap)
543  return failure();
544  rewriter.replaceOp(op, *maybeExpandedMap);
545  return success();
546  }
547 };
548 
549 /// Apply the affine map from an 'affine.load' operation to its operands, and
550 /// feed the results to a newly created 'memref.load' operation (which replaces
551 /// the original 'affine.load').
552 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
553 public:
555 
556  LogicalResult matchAndRewrite(AffineLoadOp op,
557  PatternRewriter &rewriter) const override {
558  // Expand affine map from 'affineLoadOp'.
559  SmallVector<Value, 8> indices(op.getMapOperands());
560  auto resultOperands =
561  expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
562  if (!resultOperands)
563  return failure();
564 
565  // Build vector.load memref[expandedMap.results].
566  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
567  *resultOperands);
568  return success();
569  }
570 };
571 
572 /// Apply the affine map from an 'affine.prefetch' operation to its operands,
573 /// and feed the results to a newly created 'memref.prefetch' operation (which
574 /// replaces the original 'affine.prefetch').
575 class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
576 public:
578 
579  LogicalResult matchAndRewrite(AffinePrefetchOp op,
580  PatternRewriter &rewriter) const override {
581  // Expand affine map from 'affinePrefetchOp'.
582  SmallVector<Value, 8> indices(op.getMapOperands());
583  auto resultOperands =
584  expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
585  if (!resultOperands)
586  return failure();
587 
588  // Build memref.prefetch memref[expandedMap.results].
589  rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
590  op, op.memref(), *resultOperands, op.isWrite(), op.localityHint(),
591  op.isDataCache());
592  return success();
593  }
594 };
595 
596 /// Apply the affine map from an 'affine.store' operation to its operands, and
597 /// feed the results to a newly created 'memref.store' operation (which replaces
598 /// the original 'affine.store').
599 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
600 public:
602 
603  LogicalResult matchAndRewrite(AffineStoreOp op,
604  PatternRewriter &rewriter) const override {
605  // Expand affine map from 'affineStoreOp'.
606  SmallVector<Value, 8> indices(op.getMapOperands());
607  auto maybeExpandedMap =
608  expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
609  if (!maybeExpandedMap)
610  return failure();
611 
612  // Build memref.store valueToStore, memref[expandedMap.results].
613  rewriter.replaceOpWithNewOp<memref::StoreOp>(
614  op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
615  return success();
616  }
617 };
618 
619 /// Apply the affine maps from an 'affine.dma_start' operation to each of their
620 /// respective map operands, and feed the results to a newly created
621 /// 'memref.dma_start' operation (which replaces the original
622 /// 'affine.dma_start').
623 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
624 public:
626 
627  LogicalResult matchAndRewrite(AffineDmaStartOp op,
628  PatternRewriter &rewriter) const override {
629  SmallVector<Value, 8> operands(op.getOperands());
630  auto operandsRef = llvm::makeArrayRef(operands);
631 
632  // Expand affine map for DMA source memref.
633  auto maybeExpandedSrcMap = expandAffineMap(
634  rewriter, op.getLoc(), op.getSrcMap(),
635  operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
636  if (!maybeExpandedSrcMap)
637  return failure();
638  // Expand affine map for DMA destination memref.
639  auto maybeExpandedDstMap = expandAffineMap(
640  rewriter, op.getLoc(), op.getDstMap(),
641  operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
642  if (!maybeExpandedDstMap)
643  return failure();
644  // Expand affine map for DMA tag memref.
645  auto maybeExpandedTagMap = expandAffineMap(
646  rewriter, op.getLoc(), op.getTagMap(),
647  operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
648  if (!maybeExpandedTagMap)
649  return failure();
650 
651  // Build memref.dma_start operation with affine map results.
652  rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
653  op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
654  *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
655  *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
656  return success();
657  }
658 };
659 
660 /// Apply the affine map from an 'affine.dma_wait' operation tag memref,
661 /// and feed the results to a newly created 'memref.dma_wait' operation (which
662 /// replaces the original 'affine.dma_wait').
663 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
664 public:
666 
667  LogicalResult matchAndRewrite(AffineDmaWaitOp op,
668  PatternRewriter &rewriter) const override {
669  // Expand affine map for DMA tag memref.
670  SmallVector<Value, 8> indices(op.getTagIndices());
671  auto maybeExpandedTagMap =
672  expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
673  if (!maybeExpandedTagMap)
674  return failure();
675 
676  // Build memref.dma_wait operation with affine map results.
677  rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
678  op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
679  return success();
680  }
681 };
682 
683 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
684 /// and feed the results to a newly created 'vector.load' operation (which
685 /// replaces the original 'affine.vector_load').
686 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
687 public:
689 
690  LogicalResult matchAndRewrite(AffineVectorLoadOp op,
691  PatternRewriter &rewriter) const override {
692  // Expand affine map from 'affineVectorLoadOp'.
693  SmallVector<Value, 8> indices(op.getMapOperands());
694  auto resultOperands =
695  expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
696  if (!resultOperands)
697  return failure();
698 
699  // Build vector.load memref[expandedMap.results].
700  rewriter.replaceOpWithNewOp<vector::LoadOp>(
701  op, op.getVectorType(), op.getMemRef(), *resultOperands);
702  return success();
703  }
704 };
705 
706 /// Apply the affine map from an 'affine.vector_store' operation to its
707 /// operands, and feed the results to a newly created 'vector.store' operation
708 /// (which replaces the original 'affine.vector_store').
709 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
710 public:
712 
713  LogicalResult matchAndRewrite(AffineVectorStoreOp op,
714  PatternRewriter &rewriter) const override {
715  // Expand affine map from 'affineVectorStoreOp'.
716  SmallVector<Value, 8> indices(op.getMapOperands());
717  auto maybeExpandedMap =
718  expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
719  if (!maybeExpandedMap)
720  return failure();
721 
722  rewriter.replaceOpWithNewOp<vector::StoreOp>(
723  op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
724  return success();
725  }
726 };
727 
728 } // namespace
729 
731  // clang-format off
732  patterns.add<
733  AffineApplyLowering,
734  AffineDmaStartLowering,
735  AffineDmaWaitLowering,
736  AffineLoadLowering,
737  AffineMinLowering,
738  AffineMaxLowering,
739  AffineParallelLowering,
740  AffinePrefetchLowering,
741  AffineStoreLowering,
742  AffineForLowering,
743  AffineIfLowering,
744  AffineYieldOpLowering>(patterns.getContext());
745  // clang-format on
746 }
747 
749  RewritePatternSet &patterns) {
750  // clang-format off
751  patterns.add<
752  AffineVectorLoadLowering,
753  AffineVectorStoreLowering>(patterns.getContext());
754  // clang-format on
755 }
756 
757 namespace {
758 class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
759  void runOnOperation() override {
760  RewritePatternSet patterns(&getContext());
763  ConversionTarget target(getContext());
764  target
765  .addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
766  scf::SCFDialect, StandardOpsDialect, VectorDialect>();
767  if (failed(applyPartialConversion(getOperation(), target,
768  std::move(patterns))))
769  signalPassFailure();
770  }
771 };
772 } // namespace
773 
774 /// Lowers If and For operations within a function into their lower level CFG
775 /// equivalent blocks.
776 std::unique_ptr<Pass> mlir::createLowerAffinePass() {
777  return std::make_unique<LowerAffinePass>();
778 }
Affine binary operation expression.
Definition: AffineExpr.h:207
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
U cast() const
Definition: Location.h:67
Optional< SmallVector< Value, 8 > > expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap, ValueRange operands)
Create a sequence of operations that implement the affineMap applied to the given operands (as it it ...
U cast() const
Definition: Attributes.h:123
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateAffineToStdConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the Affine dialect to the Standard dialect, in particular convert structured affine control flow into CFG branch-based control flow.
unsigned getNumDims() const
Definition: AffineMap.cpp:294
Specialization of arith.constant op that returns an integer value.
Definition: Arithmetic.h:41
int64_t getValue() const
Definition: AffineExpr.cpp:508
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:329
Value getOperand(unsigned idx)
Definition: Operation.h:219
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
unsigned getNumOperands()
Definition: Operation.h:215
unsigned getPosition() const
Definition: AffineExpr.cpp:312
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
AffineExpr getRHS() const
Definition: AffineExpr.cpp:307
Base class for AffineExpr visitors/walkers.
Value getTagMemRef()
Returns the Tag MemRef associated with the DMA operation being waited on.
Definition: AffineOps.h:280
U dyn_cast() const
Definition: AffineExpr.h:281
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
AffineExpr getLHS() const
Definition: AffineExpr.cpp:304
operand_range getTagIndices()
Returns the tag memref index for this DMA operation.
Definition: AffineOps.h:292
unsigned getSrcMemRefOperandIndex()
Returns the operand index of the source memref.
Definition: AffineOps.h:89
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Base type for affine expression.
Definition: AffineExpr.h:68
std::unique_ptr< Pass > createLowerAffinePass()
Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp) to equivalent lower-level c...
Value lowerAffineLowerBound(AffineForOp op, OpBuilder &builder)
Emit code that computes the lower bound of the given affine loop using standard arithmetic operations...
static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the maximum value among the values of a (potentially) ...
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued...
Definition: AffineMap.h:38
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:266
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc)
Returns the identity value associated with an AtomicRMWKind op.
Value getSrcMemRef()
Returns the source memref for this DMA operation.
Definition: AffineOps.h:92
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:311
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
unsigned getPosition() const
Definition: AffineExpr.cpp:497
static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the minimum value among the values of a (potentially) ...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
static Value buildMinMaxReductionSeq(Location loc, arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder)
Given a range of values, emit the code that reduces them with "min" or "max" depending on the provide...
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
Definition: AffineOps.h:286
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
Definition: AffineOps.h:169
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap getSrcMap()
Returns the affine map used to access the source memref.
Definition: AffineOps.h:101
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
Do not split vector transfer operations.
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
AffineMap getDstMap()
Returns the affine map used to access the destination memref.
Definition: AffineOps.h:140
unsigned getTagMemRefOperandIndex()
Returns the operand index of the tag memref.
Definition: AffineOps.h:153
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:74
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:367
This class describes a specific conversion target.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:802
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
static void visit(Operation *op, DenseSet< Operation *> &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation...
Definition: PDL.cpp:62
This class helps build Operations.
Definition: Builders.h:177
Value lowerAffineUpperBound(AffineForOp op, OpBuilder &builder)
Emit code that computes the upper bound of the given affine loop using standard arithmetic operations...
This class provides an abstraction over the different types of ranges over Values.
void populateAffineToVectorConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert vector-related Affine ops to the Vector dialect.
MLIRContext * getContext() const
Definition: PatternMatch.h:906
unsigned getDstMemRefOperandIndex()
Returns the operand index of the destination memref.
Definition: AffineOps.h:119
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
virtual void eraseBlock(Block *block)
This method erases all operations in a block.