MLIR  21.0.0git
AffineTransformOps.cpp
Go to the documentation of this file.
1 //=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 
19 using namespace mlir;
20 using namespace mlir::affine;
21 using namespace mlir::transform;
22 
23 //===----------------------------------------------------------------------===//
24 // SimplifyBoundedAffineOpsOp
25 //===----------------------------------------------------------------------===//
26 
27 LogicalResult SimplifyBoundedAffineOpsOp::verify() {
28  if (getLowerBounds().size() != getBoundedValues().size())
29  return emitOpError() << "incorrect number of lower bounds, expected "
30  << getBoundedValues().size() << " but found "
31  << getLowerBounds().size();
32  if (getUpperBounds().size() != getBoundedValues().size())
33  return emitOpError() << "incorrect number of upper bounds, expected "
34  << getBoundedValues().size() << " but found "
35  << getUpperBounds().size();
36  return success();
37 }
38 
39 namespace {
40 /// Simplify affine.min / affine.max ops with the given constraints. They are
41 /// either rewritten to affine.apply or left unchanged.
42 template <typename OpTy>
43 struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
45  SimplifyAffineMinMaxOp(MLIRContext *ctx,
46  const FlatAffineValueConstraints &constraints,
47  PatternBenefit benefit = 1)
48  : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
49 
50  LogicalResult matchAndRewrite(OpTy op,
51  PatternRewriter &rewriter) const override {
52  FailureOr<AffineValueMap> simplified =
53  simplifyConstrainedMinMaxOp(op, constraints);
54  if (failed(simplified))
55  return failure();
56  rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
57  simplified->getOperands());
58  return success();
59  }
60 
61  const FlatAffineValueConstraints &constraints;
62 };
63 } // namespace
64 
// Applies the transform in three steps.
// 1. Every payload op behind a `bounded_values` handle becomes a symbol in a
//    constraint set, bounded by its lower bound (inclusive) and its upper
//    bound (exclusive).
// 2. Every `target` payload op is checked; each must be an affine.min or
//    affine.max.
// 3. The targets are greedily simplified under those constraints, together
//    with the affine.min/max canonicalization patterns.
//
// Emits a definite failure in any of these cases:
// - a bounded handle maps to an op that is not single-result and index-typed;
// - a target is not affine.min/max;
// - a target is also one of the bounded ops;
// - the greedy rewrite does not converge.
66 SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
67  TransformResults &results,
68  TransformState &state) {
69  // Get constraints for bounded values.
  // NOTE(review): `lbs` and `ubs` are declared on listing lines that are not
  // shown here. They are presumably integer vectors, because the upper bound
  // has `- 1` applied below. Confirm.
72  SmallVector<Value> boundedValues;
  // Bounded ops are remembered so that they can be rejected as targets below.
73  DenseSet<Operation *> boundedOps;
  // The three lists are walked in lock-step. verify() already checks that
  // both bound lists are as long as `bounded_values`.
74  for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
75  getUpperBounds())) {
76  Value handle = std::get<0>(it);
  // A handle may map to several payload ops. Each of them gets the same pair
  // of bounds.
77  for (Operation *op : state.getPayloadOps(handle)) {
78  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
79  auto diag =
81  << "expected bounded value handle to point to one or multiple "
82  "single-result index-typed ops";
83  diag.attachNote(op->getLoc()) << "multiple/non-index result";
84  return diag;
85  }
86  boundedValues.push_back(op->getResult(0));
87  boundedOps.insert(op);
88  lbs.push_back(std::get<1>(it));
89  ubs.push_back(std::get<2>(it));
90  }
91  }
92 
93  // Build constraint set.
95  for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
96  unsigned pos;
  // Reuse the existing variable if this value was already added, since the
  // same op can be collected more than once. Otherwise append it as a new
  // symbol.
97  if (!cstr.findVar(std::get<0>(it), &pos))
98  pos = cstr.appendSymbolVar(std::get<0>(it));
99  cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
100  // Note: addBound bounds are inclusive, but specified UB is exclusive.
101  cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
102  }
103 
104  // Transform all targets.
  // All targets are validated before any rewriting starts, so an invalid
  // target makes the transform fail without touching the payload IR.
105  SmallVector<Operation *> targets;
106  for (Operation *target : state.getPayloadOps(getTarget())) {
107  if (!isa<AffineMinOp, AffineMaxOp>(target)) {
108  auto diag = emitDefiniteFailure()
109  << "target must be affine.min or affine.max";
110  diag.attachNote(target->getLoc()) << "target op";
111  return diag;
112  }
113  if (boundedOps.contains(target)) {
114  auto diag = emitDefiniteFailure()
115  << "target op result must not be constrainted";
116  diag.attachNote(target->getLoc()) << "target/constrained op";
117  return diag;
118  }
119  targets.push_back(target);
120  }
  // NOTE(review): `transformed` is not referenced in any line shown below.
  // It appears to be unused; confirm and remove.
121  SmallVector<Operation *> transformed;
123  // Canonicalization patterns are needed so that affine.apply ops are composed
124  // with the remaining affine.min/max ops.
125  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
126  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
  // The simplification patterns store a reference to `cstr`, which is declared
  // on a line not shown here. `cstr` must therefore outlive the greedy rewrite
  // below.
127  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130  // Apply the simplification pattern to a fixpoint.
  // The transform rewriter's listener is forwarded to the greedy driver so
  // that it is notified of the payload rewrites.
131  if (failed(applyOpPatternsGreedily(
132  targets, frozenPatterns,
134  .setListener(
135  static_cast<RewriterBase::Listener *>(rewriter.getListener()))
137  auto diag = emitDefiniteFailure()
138  << "affine.min/max simplification did not converge";
139  return diag;
140  }
142 }
143 
// Declares the op's memory effects:
// - the `target` handle is consumed;
// - every `bounded_values` handle is only read;
// - the payload IR is modified.
144 void SimplifyBoundedAffineOpsOp::getEffects(
146  consumesHandle(getTargetMutable(), effects);
  // The bounded values only feed the constraint set, so their handles are
  // read but stay valid.
147  for (OpOperand &operand : getBoundedValuesMutable())
148  onlyReadsHandle(operand, effects);
  // Target ops may be replaced by affine.apply during apply().
149  modifiesPayload(effects);
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // Transform op registration
154 //===----------------------------------------------------------------------===//
155 
156 namespace {
/// Transform dialect extension that registers the Affine transform ops with
/// the Transform dialect. The ops are those listed in the generated
/// AffineTransformOps.cpp.inc.
157 class AffineTransformDialectExtension
159  AffineTransformDialectExtension> {
160 public:
  // Gives this extension class an explicit TypeID, defined inline.
161  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
162 
  // Inherit the constructors of the extension base class.
163  using Base::Base;
164 
  // Populates the extension. It first declares the Affine dialect as one the
  // registered ops may generate, then registers every op from the generated
  // op list.
165  void init() {
166  declareGeneratedDialect<AffineDialect>();
167 
168  registerTransformOps<
169 #define GET_OP_LIST
170 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
171  >();
172  }
173 };
174 } // namespace
175 
// Pull in the generated op class definitions for the Affine transform ops
// (the GET_OP_CLASSES section of the .cpp.inc file).
176 #define GET_OP_CLASSES
177 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
178 
180  DialectRegistry &registry) {
  // Add the Affine transform-ops extension to `registry`.
181  registry.addExtensions<AffineTransformDialectExtension>();
182 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
bool findVar(Value val, unsigned *pos, unsigned offset=0) const
Looks up the position of the variable with the specified Value, starting with variables at offset `offset`.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
This class represents an operand of an operation.
Definition: Value.h:243
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position, with constraints drawn from the specified bound map and operands.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Tries to simplify the given affine.min or affine.max op to an affine map with a single result and its operands, taking the given constraints into account.
Definition: Utils.cpp:2205
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrites the specified ops by repeatedly applying the highest-benefit patterns in a greedy worklist driver.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314