MLIR  22.0.0git
AffineTransformOps.cpp
Go to the documentation of this file.
1 //=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 
20 using namespace mlir;
21 using namespace mlir::affine;
22 using namespace mlir::transform;
23 
24 //===----------------------------------------------------------------------===//
25 // SimplifyBoundedAffineOpsOp
26 //===----------------------------------------------------------------------===//
27 
28 LogicalResult SimplifyBoundedAffineOpsOp::verify() {
29  if (getLowerBounds().size() != getBoundedValues().size())
30  return emitOpError() << "incorrect number of lower bounds, expected "
31  << getBoundedValues().size() << " but found "
32  << getLowerBounds().size();
33  if (getUpperBounds().size() != getBoundedValues().size())
34  return emitOpError() << "incorrect number of upper bounds, expected "
35  << getBoundedValues().size() << " but found "
36  << getUpperBounds().size();
37  return success();
38 }
39 
40 namespace {
41 /// Simplify affine.min / affine.max ops with the given constraints. They are
42 /// either rewritten to affine.apply or left unchanged.
///
/// `OpTy` is the op class the pattern is rooted on; this file instantiates the
/// pattern with AffineMinOp and AffineMaxOp.
///
/// The pattern keeps a *reference* to the constraint set rather than a copy,
/// so whoever constructs it must keep that constraint set alive for as long as
/// the pattern may be applied.
43 template <typename OpTy>
44 struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
 // Builds the pattern for context `ctx` with the given `benefit` and binds the
 // `constraints` reference member (no copy is made here).
46  SimplifyAffineMinMaxOp(MLIRContext *ctx,
47  const FlatAffineValueConstraints &constraints,
48  PatternBenefit benefit = 1)
49  : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
50
 // Asks simplifyConstrainedMinMaxOp to simplify `op` under the stored
 // constraints. If that fails the pattern reports a failed match; otherwise
 // `op` is replaced by an AffineApplyOp built from the simplified map and
 // operands, and success is returned.
51  LogicalResult matchAndRewrite(OpTy op,
52  PatternRewriter &rewriter) const override {
53  FailureOr<AffineValueMap> simplified =
54  simplifyConstrainedMinMaxOp(op, constraints);
55  if (failed(simplified))
56  return failure();
57  rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
58  simplified->getOperands());
59  return success();
60  }
61
 // Non-owning reference to the constraints used for simplification; see the
 // lifetime requirement in the struct comment.
62  const FlatAffineValueConstraints &constraints;
63 };
64 } // namespace
65 
// Transform-op entry point: simplifies the targeted affine.min / affine.max
// payload ops using the lower/upper bounds supplied for the bounded values.
//
// Steps, as far as they are visible in this listing:
//   1. Resolve every bounded-value handle to its payload ops. Each such op
//      must have exactly one result and that result must be index-typed;
//      otherwise a diagnostic with an attached note is returned.
//   2. For every bounded value, reuse its existing variable in `cstr` or
//      append a new symbol variable, then add its lower and upper bound.
//   3. Check that every target is an affine.min or affine.max op and is not
//      itself one of the bounded ops; otherwise return a definite failure.
//   4. Combine the AffineMaxOp / AffineMinOp canonicalization patterns with
//      SimplifyAffineMinMaxOp for both op kinds and run them on the targets,
//      returning a definite failure if that does not succeed.
//
// NOTE(review): this listing omits several source lines (the embedded line
// numbers jump, e.g. 70 -> 73, 80 -> 82, 94 -> 96, 121 -> 123, 129 -> 132,
// 140 -> 142). The return type, the declarations of `lbs`, `ubs`, `cstr` and
// `patterns`, the call that consumes `frozenPatterns` (presumably
// applyOpPatternsGreedily -- confirm against the real file), and the final
// return statement are therefore not visible here.
67 SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
68  TransformResults &results,
69  TransformState &state) {
70  // Get constraints for bounded values.
73  SmallVector<Value> boundedValues;
74  DenseSet<Operation *> boundedOps;
75  for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
76  getUpperBounds())) {
77  Value handle = std::get<0>(it);
78  for (Operation *op : state.getPayloadOps(handle)) {
79  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
80  auto diag =
82  << "expected bounded value handle to point to one or multiple "
83  "single-result index-typed ops";
84  diag.attachNote(op->getLoc()) << "multiple/non-index result";
85  return diag;
86  }
 // One entry is recorded per payload op, so a handle that maps to several
 // ops repeats its lower/upper bound once for each of them. This keeps
 // `boundedValues`, `lbs` and `ubs` the same length.
87  boundedValues.push_back(op->getResult(0));
88  boundedOps.insert(op);
89  lbs.push_back(std::get<1>(it));
90  ubs.push_back(std::get<2>(it));
91  }
92  }
93
94  // Build constraint set.
96  for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
97  unsigned pos;
 // `pos` is filled in by findVar when the value is already tracked, and by
 // appendSymbolVar otherwise, so it is always set before use below.
98  if (!cstr.findVar(std::get<0>(it), &pos))
99  pos = cstr.appendSymbolVar(std::get<0>(it));
100  cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
101  // Note: addBound bounds are inclusive, but specified UB is exclusive.
102  cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
103  }
104
105  // Transform all targets.
106  SmallVector<Operation *> targets;
107  for (Operation *target : state.getPayloadOps(getTarget())) {
108  if (!isa<AffineMinOp, AffineMaxOp>(target)) {
109  auto diag = emitDefiniteFailure()
110  << "target must be affine.min or affine.max";
111  diag.attachNote(target->getLoc()) << "target op";
112  return diag;
113  }
 // An op may not be both a target and a bounded op.
114  if (boundedOps.contains(target)) {
115  auto diag = emitDefiniteFailure()
116  << "target op result must not be constrained";
117  diag.attachNote(target->getLoc()) << "target/constrained op";
118  return diag;
119  }
120  targets.push_back(target);
121  }
123  // Canonicalization patterns are needed so that affine.apply ops are composed
124  // with the remaining affine.min/max ops.
125  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
126  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
 // The simplification patterns hold a reference to `cstr`, so `cstr` must
 // stay alive while the patterns are applied.
127  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130  // Apply the simplification pattern to a fixpoint.
132  targets, frozenPatterns,
134  .setListener(
135  static_cast<RewriterBase::Listener *>(rewriter.getListener()))
137  auto diag = emitDefiniteFailure()
138  << "affine.min/max simplification did not converge";
139  return diag;
140  }
142 }
143 
// Populates `effects` with this op's memory effects: the target handle is
// consumed, every bounded-value handle is only read, and the payload IR is
// modified.
// NOTE(review): the parameter line (original line 145) is missing from this
// listing; `effects` is declared there.
144 void SimplifyBoundedAffineOpsOp::getEffects(
146  consumesHandle(getTargetMutable(), effects);
147  for (OpOperand &operand : getBoundedValuesMutable())
148  onlyReadsHandle(operand, effects);
149  modifiesPayload(effects);
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // SimplifyMinMaxAffineOpsOp
154 //===----------------------------------------------------------------------===//
// Transform-op entry point: runs simplifyAffineMinMaxOps on the payload ops
// the target handle maps to.
//
// Outcomes visible in this listing:
//   - definite failure if any target is not an affine.min / affine.max op
//     (a note pointing at the offending op is attached);
//   - definite failure if simplifyAffineMinMaxOps itself fails;
//   - silenceable error if the helper succeeds but reports that it modified
//     nothing.
//
// NOTE(review): the return type (original line 155) and the final return
// statement (original line 179) are missing from this listing.
156 SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
157  TransformResults &results,
158  TransformState &state) {
 // Validate every target up front so nothing is rewritten if any one of
 // them is of the wrong kind.
159  SmallVector<Operation *> targets;
160  for (Operation *target : state.getPayloadOps(getTarget())) {
161  if (!isa<AffineMinOp, AffineMaxOp>(target)) {
162  auto diag = emitDefiniteFailure()
163  << "target must be affine.min or affine.max";
164  diag.attachNote(target->getLoc()) << "target op";
165  return diag;
166  }
167  targets.push_back(target);
168  }
 // `modified` is an out-parameter set by simplifyAffineMinMaxOps.
169  bool modified = false;
170  if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
171  &modified))) {
172  return emitDefiniteFailure()
173  << "affine.min/max simplification did not converge";
174  }
175  if (!modified) {
176  return emitSilenceableError()
177  << "the transform failed to simplify any of the target operations";
178  }
180 }
181 
// Populates `effects` with this op's memory effects: the target handle is
// consumed and the payload IR is modified.
// NOTE(review): the parameter line (original line 183) is missing from this
// listing; `effects` is declared there.
182 void SimplifyMinMaxAffineOpsOp::getEffects(
184  consumesHandle(getTargetMutable(), effects);
185  modifiesPayload(effects);
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // Transform op registration
190 //===----------------------------------------------------------------------===//
191 
192 namespace {
// Transform dialect extension that contributes the Affine transform ops.
// NOTE(review): the base-class line (original line 194) is missing from this
// listing; only the trailing template argument on line 195 is visible.
193 class AffineTransformDialectExtension
195  AffineTransformDialectExtension> {
196 public:
197  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
198
 // Reuse the base class constructors.
199  using Base::Base;
200
 // Calls declareGeneratedDialect for AffineDialect and registers every
 // transform op named by the GET_OP_LIST section of the generated
 // AffineTransformOps.cpp.inc file.
201  void init() {
202  declareGeneratedDialect<AffineDialect>();
203
204  registerTransformOps<
205 #define GET_OP_LIST
206 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
207  >();
208  }
209 };
210 } // namespace
211 
212 #define GET_OP_CLASSES
213 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
214 
216  DialectRegistry &registry) {
 // Adds AffineTransformDialectExtension to `registry`.
 // NOTE(review): the line carrying this function's return type and name
 // (original line 215) is missing from this listing.
217  registry.addExtensions<AffineTransformDialectExtension>();
218 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
bool findVar(Value val, unsigned *pos, unsigned offset=0) const
Looks up the position of the variable with the specified Value starting with variables at offset offs...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter, ArrayRef< Operation * > ops, bool *modified=nullptr)
This transform applies simplifyAffineMinOp and simplifyAffineMaxOp to all the affine....
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition: Utils.cpp:2190
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314