MLIR  20.0.0git
AffineTransformOps.cpp
Go to the documentation of this file.
1 //=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 
19 using namespace mlir;
20 using namespace mlir::affine;
21 using namespace mlir::transform;
22 
23 //===----------------------------------------------------------------------===//
24 // SimplifyBoundedAffineOpsOp
25 //===----------------------------------------------------------------------===//
26 
27 LogicalResult SimplifyBoundedAffineOpsOp::verify() {
28  if (getLowerBounds().size() != getBoundedValues().size())
29  return emitOpError() << "incorrect number of lower bounds, expected "
30  << getBoundedValues().size() << " but found "
31  << getLowerBounds().size();
32  if (getUpperBounds().size() != getBoundedValues().size())
33  return emitOpError() << "incorrect number of upper bounds, expected "
34  << getBoundedValues().size() << " but found "
35  << getUpperBounds().size();
36  return success();
37 }
38 
39 namespace {
40 /// Simplify affine.min / affine.max ops with the given constraints. They are
41 /// either rewritten to affine.apply or left unchanged.
42 template <typename OpTy>
43 struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
45  SimplifyAffineMinMaxOp(MLIRContext *ctx,
46  const FlatAffineValueConstraints &constraints,
47  PatternBenefit benefit = 1)
48  : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
49 
50  LogicalResult matchAndRewrite(OpTy op,
51  PatternRewriter &rewriter) const override {
52  FailureOr<AffineValueMap> simplified =
53  simplifyConstrainedMinMaxOp(op, constraints);
54  if (failed(simplified))
55  return failure();
56  rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
57  simplified->getOperands());
58  return success();
59  }
60 
61  const FlatAffineValueConstraints &constraints;
62 };
63 } // namespace
64 
// Applies the transform op. In order, the visible code:
//   1. walks the bounded-value handles together with their lower/upper bounds
//      and records, for every payload op behind a handle, its single result
//      and that handle's bounds;
//   2. adds those bounds to the constraint set `cstr`;
//   3. collects the target payload ops, rejecting any op that is not
//      affine.min / affine.max or that is itself one of the bounded ops;
//   4. runs the affine.min/max canonicalization patterns together with
//      SimplifyAffineMinMaxOp greedily on the targets.
// NOTE(review): this listing has gaps (the embedded line numbers skip 65,
// 70-71, 80, 94, 122, 130, 133 and 140). The return type, the declarations of
// `lbs`, `ubs`, `cstr`, `patterns` and `config`, and the final return statement
// are not visible here — confirm against the original file.
66 SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
67  TransformResults &results,
68  TransformState &state) {
69  // Get constraints for bounded values.
 // NOTE(review): `lbs` / `ubs` are used below but declared on lines missing
 // from this listing; presumably per-value bound vectors — confirm.
72  SmallVector<Value> boundedValues;
73  DenseSet<Operation *> boundedOps;
 // zip_equal walks handles and both bound lists in lockstep; verify() rejects
 // ops whose bound lists differ in length from the bounded-value list.
74  for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
75  getUpperBounds())) {
76  Value handle = std::get<0>(it);
 // A handle may map to several payload ops; each one receives the same
 // lower/upper bound pair of that handle.
77  for (Operation *op : state.getPayloadOps(handle)) {
 // Only ops with exactly one result of index type can be bounded.
78  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
79  auto diag =
 // NOTE(review): the expression that creates the diagnostic (embedded
 // line 80) is missing from this listing.
81  << "expected bounded value handle to point to one or multiple "
82  "single-result index-typed ops";
83  diag.attachNote(op->getLoc()) << "multiple/non-index result";
84  return diag;
85  }
86  boundedValues.push_back(op->getResult(0));
 // Remembered so that targets which are themselves bounded can be rejected.
87  boundedOps.insert(op);
88  lbs.push_back(std::get<1>(it));
89  ubs.push_back(std::get<2>(it));
90  }
91  }
92 
93  // Build constraint set.
 // NOTE(review): the declaration of `cstr` (embedded line 94) is missing from
 // this listing.
95  for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
96  unsigned pos;
 // Reuse the position of a value that is already part of the set; otherwise
 // append the value as a new symbol variable.
97  if (!cstr.findVar(std::get<0>(it), &pos))
98  pos = cstr.appendSymbolVar(std::get<0>(it));
99  cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
100  // Note: addBound bounds are inclusive, but specified UB is exclusive.
101  cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
102  }
103 
104  // Transform all targets.
105  SmallVector<Operation *> targets;
106  for (Operation *target : state.getPayloadOps(getTarget())) {
 // Only affine.min / affine.max payload ops are accepted as targets.
107  if (!isa<AffineMinOp, AffineMaxOp>(target)) {
108  auto diag = emitDefiniteFailure()
109  << "target must be affine.min or affine.max";
110  diag.attachNote(target->getLoc()) << "target op";
111  return diag;
112  }
 // An op may not be a target and one of the bounded ops at the same time.
113  if (boundedOps.contains(target)) {
114  auto diag = emitDefiniteFailure()
115  << "target op result must not be constrainted";
116  diag.attachNote(target->getLoc()) << "target/constrained op";
117  return diag;
118  }
119  targets.push_back(target);
120  }
 // NOTE(review): `transformed` is never read or written in the visible code.
121  SmallVector<Operation *> transformed;
 // NOTE(review): the declaration of `patterns` (embedded line 122) is missing
 // from this listing.
123  // Canonicalization patterns are needed so that affine.apply ops are composed
124  // with the remaining affine.min/max ops.
125  AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
126  AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
 // Both simplification patterns receive the constraint set built above.
127  patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128  SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
 // NOTE(review): the declaration of `config` and one further statement
 // (embedded lines 130 and 133) are missing from this listing.
 // The transform rewriter's listener is handed to the greedy driver config.
131  config.listener =
132  static_cast<RewriterBase::Listener *>(rewriter.getListener());
134  // Apply the simplification pattern to a fixpoint.
135  if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
136  auto diag = emitDefiniteFailure()
137  << "affine.min/max simplification did not converge";
138  return diag;
139  }
 // NOTE(review): the success return (embedded line 140) is missing from this
 // listing.
141 }
142 
// Reports the effects of this transform op into `effects`: the target handle
// is consumed, every bounded-value handle is only read, and the payload is
// modified.
// NOTE(review): the parameter-list line that declares `effects` (embedded line
// 144) is missing from this listing — confirm against the original file.
143 void SimplifyBoundedAffineOpsOp::getEffects(
145  consumesHandle(getTargetMutable(), effects);
 // Bounded values are passed to onlyReadsHandle, not consumed.
146  for (OpOperand &operand : getBoundedValuesMutable())
147  onlyReadsHandle(operand, effects);
148  modifiesPayload(effects);
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Transform op registration
153 //===----------------------------------------------------------------------===//
154 
155 namespace {
// Extension class that registers the Affine transform ops. Its init() declares
// AffineDialect as a generated dialect and registers every op named in the
// GET_OP_LIST section of AffineTransformOps.cpp.inc.
// NOTE(review): the base-class line (embedded line 157) is missing from this
// listing; only its trailing template argument is visible — confirm against
// the original file.
156 class AffineTransformDialectExtension
158  AffineTransformDialectExtension> {
159 public:
160  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
161 
 // Reuse the constructors of the base class.
162  using Base::Base;
163 
 // Declares the generated dialect and registers the transform ops.
164  void init() {
165  declareGeneratedDialect<AffineDialect>();
166 
 // GET_OP_LIST makes the included .inc file supply the op list used as the
 // template arguments of registerTransformOps.
167  registerTransformOps<
168 #define GET_OP_LIST
169 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
170  >();
171  }
172 };
173 } // namespace
174 
175 #define GET_OP_CLASSES
176 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
177 
179  DialectRegistry &registry) {
 // Adds AffineTransformDialectExtension to the caller-provided registry.
 // NOTE(review): the first line of this definition (embedded line 178) is
 // missing from this listing — confirm the full signature against the
 // original file.
180  registry.addExtensions<AffineTransformDialectExtension>();
181 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
bool findVar(Value val, unsigned *pos, unsigned offset=0) const
Looks up the position of the variable with the specified Value starting with variables at offset offs...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition: Utils.cpp:2066
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358