MLIR  19.0.0git
AffineTransformOps.cpp
Go to the documentation of this file.
1 //=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 
19 using namespace mlir;
20 using namespace mlir::affine;
21 using namespace mlir::transform;
22 
23 //===----------------------------------------------------------------------===//
24 // SimplifyBoundedAffineOpsOp
25 //===----------------------------------------------------------------------===//
26 
// Verifier body for the op named in the section header above
// (SimplifyBoundedAffineOpsOp).
// NOTE(review): the signature line (listing line 27) is not present in this
// extract -- confirm the exact declaration against the original file.
//
// Every bounded value must come with exactly one lower bound and one upper
// bound. If either list has a different length than the list of bounded
// values, an op error stating the expected and the actual count is returned.
28  if (getLowerBounds().size() != getBoundedValues().size())
29  return emitOpError() << "incorrect number of lower bounds, expected "
30  << getBoundedValues().size() << " but found "
31  << getLowerBounds().size();
32  if (getUpperBounds().size() != getBoundedValues().size())
33  return emitOpError() << "incorrect number of upper bounds, expected "
34  << getBoundedValues().size() << " but found "
35  << getUpperBounds().size();
// Both bound lists have the same length as the list of bounded values.
36  return success();
37 }
38 
39 namespace {
40 /// Simplify affine.min / affine.max ops with the given constraints. They are
41 /// either rewritten to affine.apply or left unchanged.
///
/// `OpTy` is the op type the pattern matches; this file instantiates the
/// pattern with AffineMinOp and AffineMaxOp.
/// NOTE(review): listing line 44 is missing from this extract; confirm what it
/// contained against the original file.
42 template <typename OpTy>
43 struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
/// Creates the pattern for context `ctx` with the given `benefit`. Only a
/// reference to `constraints` is stored (see the member below), so the
/// constraint set must stay alive for as long as the pattern is used.
45  SimplifyAffineMinMaxOp(MLIRContext *ctx,
46  const FlatAffineValueConstraints &constraints,
47  PatternBenefit benefit = 1)
48  : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
49 
/// Tries to simplify `op` under the stored constraints via
/// simplifyConstrainedMinMaxOp. Returns failure() without rewriting `op` if
/// that helper fails; otherwise replaces `op` with an AffineApplyOp built
/// from the simplified map and operands and returns success().
50  LogicalResult matchAndRewrite(OpTy op,
51  PatternRewriter &rewriter) const override {
52  FailureOr<AffineValueMap> simplified =
53  simplifyConstrainedMinMaxOp(op, constraints);
54  if (failed(simplified))
55  return failure();
56  rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
57  simplified->getOperands());
58  return success();
59  }
60 
/// Constraints used for the simplification; held by reference, not owned.
61  const FlatAffineValueConstraints &constraints;
62 };
63 } // namespace
64 
// Applies the transform op: collects the bounds given for the payload values,
// builds a constraint set from them, and greedily simplifies the targeted
// affine.min / affine.max ops under those constraints.
//
// NOTE(review): several listing lines are missing from this extract (65,
// 70-71, 80, 94, 133 and 140). The return type, the declarations of `lbs`,
// `ubs` and `cstr`, the call that creates the first diagnostic, and the final
// return statement are therefore not visible here -- confirm them against the
// original file.
// NOTE(review): neither the `results` parameter nor the local `transformed`
// is referenced in the visible lines.
66 SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
67  TransformResults &results,
68  TransformState &state) {
69  // Get constraints for bounded values.
72  SmallVector<Value> boundedValues;
73  DenseSet<Operation *> boundedOps;
// Each handle is paired with one lower and one upper bound; every payload op
// the handle maps to receives that same pair.
74  for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
75  getUpperBounds())) {
76  Value handle = std::get<0>(it);
77  for (Operation *op : state.getPayloadOps(handle)) {
// Only ops with exactly one result, of index type, can be bounded.
78  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
79  auto diag =
81  << "expected bounded value handle to point to one or multiple "
82  "single-result index-typed ops";
83  diag.attachNote(op->getLoc()) << "multiple/non-index result";
84  return diag;
85  }
// Record the bounded value, its defining op (used below to reject targets
// that are themselves bounded), and its bounds at matching indices.
86  boundedValues.push_back(op->getResult(0));
87  boundedOps.insert(op);
88  lbs.push_back(std::get<1>(it));
89  ubs.push_back(std::get<2>(it));
90  }
91  }
92 
93  // Build constraint set.
95  for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
96  unsigned pos;
// Reuse the variable if the value is already part of the constraint set;
// otherwise append it as a new symbol variable.
97  if (!cstr.findVar(std::get<0>(it), &pos))
98  pos = cstr.appendSymbolVar(std::get<0>(it));
99  cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
100  // Note: addBound bounds are inclusive, but specified UB is exclusive.
101  cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
102  }
103 
104  // Transform all targets.
// Validate the targets first: each one must be an affine.min / affine.max op
// and must not itself be one of the bounded ops collected above.
105  SmallVector<Operation *> targets;
106  for (Operation *target : state.getPayloadOps(getTarget())) {
107  if (!isa<AffineMinOp, AffineMaxOp>(target)) {
108  auto diag = emitDefiniteFailure()
109  << "target must be affine.min or affine.max";
110  diag.attachNote(target->getLoc()) << "target op";
111  return diag;
112  }
113  if (boundedOps.contains(target)) {
// NOTE(review): "constrainted" in the message below looks like a typo for
// "constrained"; it is runtime text and is left unchanged here.
114  auto diag = emitDefiniteFailure()
115  << "target op result must not be constrainted";
116  diag.attachNote(target->getLoc()) << "target/constrained op";
117  return diag;
118  }
119  targets.push_back(target);
120  }
121  SmallVector<Operation *> transformed;
122  RewritePatternSet patterns(getContext());
123  // Canonicalization patterns are needed so that affine.apply ops are composed
124  // with the remaining affine.min/max ops.
125  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
126  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
// The simplification patterns keep a reference to `cstr`, which lives until
// the end of this function.
127  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130  GreedyRewriteConfig config;
// Hand the transform rewriter's listener to the greedy driver so that it is
// notified about IR modifications made by the patterns.
131  config.listener =
132  static_cast<RewriterBase::Listener *>(rewriter.getListener());
134  // Apply the simplification pattern to a fixpoint.
135  if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
136  auto diag = emitDefiniteFailure()
137  << "affine.min/max simplification did not converge";
138  return diag;
139  }
141 }
142 
// Populates `effects` for this op: the target handle is marked via
// consumesHandle, every bounded-value handle via onlyReadsHandle, and the
// payload IR via modifiesPayload.
// NOTE(review): listing line 144 (the parameter list that declares `effects`)
// is missing from this extract -- confirm against the original file.
143 void SimplifyBoundedAffineOpsOp::getEffects(
145  consumesHandle(getTarget(), effects);
146  for (Value v : getBoundedValues())
147  onlyReadsHandle(v, effects);
148  modifiesPayload(effects);
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Transform op registration
153 //===----------------------------------------------------------------------===//
154 
155 namespace {
// Extension class through which the ops defined in this file are registered.
// NOTE(review): listing line 157 (the start of the base-class specifier) is
// missing from this extract; the visible template argument suggests a CRTP
// base parameterized on this class -- confirm against the original file.
156 class AffineTransformDialectExtension
158  AffineTransformDialectExtension> {
159 public:
// Inherit the base class constructors.
160  using Base::Base;
161 
// Declares AffineDialect as a generated dialect and registers the transform
// ops enumerated by the GET_OP_LIST section of the generated
// AffineTransformOps.cpp.inc file.
162  void init() {
163  declareGeneratedDialect<AffineDialect>();
164 
165  registerTransformOps<
166 #define GET_OP_LIST
167 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
168  >();
169  }
170 };
171 } // namespace
172 
173 #define GET_OP_CLASSES
174 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
175 
// Adds AffineTransformDialectExtension to `registry`.
// NOTE(review): listing line 176 (the start of the signature) is missing from
// this extract. The page index lists
// `void registerTransformDialectExtension(DialectRegistry &registry)`, which is
// presumably this function -- confirm the name and its namespace qualification
// against the original file.
177  DialectRegistry &registry) {
178  registry.addExtensions<AffineTransformDialectExtension>();
179 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
bool findVar(Value val, unsigned *pos, unsigned offset=0) const
Looks up the position of the variable with the specified Value starting with variables at offset offs...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:775
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:534
bool isIndex() const
Definition: Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition: Utils.cpp:2066
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357