MLIR 22.0.0git
AffineTransformOps.cpp
Go to the documentation of this file.
1//=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
19
20using namespace mlir;
21using namespace mlir::affine;
22using namespace mlir::transform;
23
24//===----------------------------------------------------------------------===//
25// SimplifyBoundedAffineOpsOp
26//===----------------------------------------------------------------------===//
27
28LogicalResult SimplifyBoundedAffineOpsOp::verify() {
29 if (getLowerBounds().size() != getBoundedValues().size())
30 return emitOpError() << "incorrect number of lower bounds, expected "
31 << getBoundedValues().size() << " but found "
32 << getLowerBounds().size();
33 if (getUpperBounds().size() != getBoundedValues().size())
34 return emitOpError() << "incorrect number of upper bounds, expected "
35 << getBoundedValues().size() << " but found "
36 << getUpperBounds().size();
37 return success();
38}
39
40namespace {
41/// Simplify affine.min / affine.max ops with the given constraints. They are
42/// either rewritten to affine.apply or left unchanged.
43template <typename OpTy>
44struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
45 using OpRewritePattern<OpTy>::OpRewritePattern;
46 SimplifyAffineMinMaxOp(MLIRContext *ctx,
47 const FlatAffineValueConstraints &constraints,
48 PatternBenefit benefit = 1)
49 : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
50
51 LogicalResult matchAndRewrite(OpTy op,
52 PatternRewriter &rewriter) const override {
53 FailureOr<AffineValueMap> simplified =
54 simplifyConstrainedMinMaxOp(op, constraints);
55 if (failed(simplified))
56 return failure();
57 rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
58 simplified->getOperands());
59 return success();
60 }
61
62 const FlatAffineValueConstraints &constraints;
63};
64} // namespace
65
// Applies the bounded simplification. In order, the visible code:
//  1. gathers the single index-typed result of every payload op behind each
//     bounded-value handle, together with that handle's lower/upper bound;
//  2. records those bounds on the corresponding variables of `cstr`;
//  3. collects the target payload ops, rejecting any that is not
//     affine.min/affine.max or that is itself one of the bounded ops;
//  4. populates `patterns` with the affine.min/max canonicalization patterns
//     plus SimplifyAffineMinMaxOp and freezes them for application to the
//     targets.
// NOTE(review): several lines of this function are missing from this listing
// (the embedded line numbers skip, e.g. 70 -> 73 and 94 -> 96), including the
// return type and the declarations of `lbs`, `ubs`, `cstr`, `targets`,
// `patterns` and some `diag` initializers. Confirm against the original file.
67SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
68 TransformResults &results,
69 TransformState &state) {
70 // Get constraints for bounded values.
73 SmallVector<Value> boundedValues;
74 DenseSet<Operation *> boundedOps;
  // Walk the bounded-value handles in lockstep with their lower/upper bounds.
75 for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
76 getUpperBounds())) {
77 Value handle = std::get<0>(it);
  // A handle may map to several payload ops; each of them gets the same bounds.
78 for (Operation *op : state.getPayloadOps(handle)) {
  // Only ops with exactly one result of index type can be bounded.
79 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
80 auto diag =
82 << "expected bounded value handle to point to one or multiple "
83 "single-result index-typed ops";
84 diag.attachNote(op->getLoc()) << "multiple/non-index result";
85 return diag;
86 }
87 boundedValues.push_back(op->getResult(0));
  // Remember the defining op so that it can be rejected as a target below.
88 boundedOps.insert(op);
89 lbs.push_back(std::get<1>(it));
90 ubs.push_back(std::get<2>(it));
91 }
92 }
93
94 // Build constraint set.
96 for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
97 unsigned pos;
  // Reuse the variable if the value is already present in `cstr`; otherwise
  // append it as a new symbol variable.
98 if (!cstr.findVar(std::get<0>(it), &pos))
99 pos = cstr.appendSymbolVar(std::get<0>(it));
100 cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
101 // Note: addBound bounds are inclusive, but specified UB is exclusive.
102 cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
103 }
104
105 // Transform all targets.
107 for (Operation *target : state.getPayloadOps(getTarget())) {
  // Only affine.min and affine.max ops are accepted as targets.
108 if (!isa<AffineMinOp, AffineMaxOp>(target)) {
110 << "target must be affine.min or affine.max";
111 diag.attachNote(target->getLoc()) << "target op";
112 return diag;
113 }
  // An op whose result was given bounds above must not also be a target.
114 if (boundedOps.contains(target)) {
116 << "target op result must not be constrained";
117 diag.attachNote(target->getLoc()) << "target/constrained op";
118 return diag;
119 }
120 targets.push_back(target);
121 }
123 // Canonicalization patterns are needed so that affine.apply ops are composed
124 // with the remaining affine.min/max ops.
125 AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
126 AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
  // Both pattern instantiations receive the same constraint set `cstr`.
127 patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128 SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130 // Apply the simplification pattern to a fixpoint.
132 targets, frozenPatterns,
134 .setListener(
135 static_cast<RewriterBase::Listener *>(rewriter.getListener()))
138 << "affine.min/max simplification did not converge";
139 return diag;
140 }
142}
143
// Declares the effects of this transform op: the target handle is consumed,
// every bounded-value handle is only read, and the payload is modified.
// NOTE(review): the line declaring the `effects` parameter is missing from this
// listing (embedded line numbers skip 144 -> 146).
144void SimplifyBoundedAffineOpsOp::getEffects(
146 consumesHandle(getTargetMutable(), effects);
147 for (OpOperand &operand : getBoundedValuesMutable())
148 onlyReadsHandle(operand, effects);
149 modifiesPayload(effects);
150}
151
152//===----------------------------------------------------------------------===//
153// SimplifyMinMaxAffineOpsOp
154//===----------------------------------------------------------------------===//
// Applies affine.min/max simplification to the target payload ops. Every
// target must be an affine.min or affine.max op; the collected ops are handed
// to simplifyAffineMinMaxOps. A failure of that call is reported as a definite
// failure, while "nothing was modified" is reported as a silenceable error.
// NOTE(review): some lines of this function are missing from this listing
// (embedded line numbers skip 158 -> 160, 161 -> 163 and 178 -> 180), including
// the return type, the declaration of `targets`, the start of the `diag`
// initializer and the final return statement.
156SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
157 TransformResults &results,
158 TransformState &state) {
160 for (Operation *target : state.getPayloadOps(getTarget())) {
  // Only affine.min and affine.max ops are accepted as targets.
161 if (!isa<AffineMinOp, AffineMaxOp>(target)) {
163 << "target must be affine.min or affine.max";
164 diag.attachNote(target->getLoc()) << "target op";
165 return diag;
166 }
167 targets.push_back(target);
168 }
  // Set by simplifyAffineMinMaxOps through the pointer passed below.
169 bool modified = false;
170 if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
171 &modified))) {
172 return emitDefiniteFailure()
173 << "affine.min/max simplification did not converge";
174 }
  // Not simplifying anything is reported as a silenceable error.
175 if (!modified) {
176 return emitSilenceableError()
177 << "the transform failed to simplify any of the target operations";
178 }
180}
181
// Declares the effects of this transform op: the target handle is consumed and
// the payload is modified.
// NOTE(review): the line declaring the `effects` parameter is missing from this
// listing (embedded line numbers skip 182 -> 184).
182void SimplifyMinMaxAffineOpsOp::getEffects(
184 consumesHandle(getTargetMutable(), effects);
185 modifiesPayload(effects);
186}
187
188//===----------------------------------------------------------------------===//
189// Transform op registration
190//===----------------------------------------------------------------------===//
191
192namespace {
// Extension class that, in init(), declares AffineDialect as a generated
// dialect and registers the transform ops listed in the generated
// AffineTransformOps.cpp.inc file.
// NOTE(review): the line naming the base class is missing from this listing
// (embedded line numbers skip 193 -> 195).
193class AffineTransformDialectExtension
195 AffineTransformDialectExtension> {
196public:
197 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
198
  // Reuse the constructors of the base class.
199 using Base::Base;
200
201 void init() {
202 declareGeneratedDialect<AffineDialect>();
203
  // GET_OP_LIST makes the included file expand to the list of op classes,
  // which is used as the template argument list of registerTransformOps.
204 registerTransformOps<
205#define GET_OP_LIST
206#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
207 >();
208 }
209};
210} // namespace
211
212#define GET_OP_CLASSES
213#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
214
// Adds AffineTransformDialectExtension to the given dialect registry.
// NOTE(review): the first line of this function's signature is missing from
// this listing (embedded line numbers skip 213 -> 216).
216 DialectRegistry &registry) {
217 registry.addExtensions<AffineTransformDialectExtension>();
218}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
b getContext())
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
bool findVar(Value val, unsigned *pos, unsigned offset=0) const
Looks up the position of the variable with the specified Value starting with variables at offset offs...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter, ArrayRef< Operation * > ops, bool *modified=nullptr)
This transform applies simplifyAffineMinOp and simplifyAffineMaxOp to all the affine....
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition Utils.cpp:2286
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...