MLIR 23.0.0git
AffineTransformOps.cpp
Go to the documentation of this file.
1//=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
21#include "llvm/ADT/ArrayRef.h"
22#include <cstdint>
23
24using namespace mlir;
25using namespace mlir::affine;
26using namespace mlir::transform;
27
28//===----------------------------------------------------------------------===//
29// SimplifyBoundedAffineOpsOp
30//===----------------------------------------------------------------------===//
31
32LogicalResult SimplifyBoundedAffineOpsOp::verify() {
33 if (getLowerBounds().size() != getBoundedValues().size())
34 return emitOpError() << "incorrect number of lower bounds, expected "
35 << getBoundedValues().size() << " but found "
36 << getLowerBounds().size();
37 if (getUpperBounds().size() != getBoundedValues().size())
38 return emitOpError() << "incorrect number of upper bounds, expected "
39 << getBoundedValues().size() << " but found "
40 << getUpperBounds().size();
41 return success();
42}
43
44namespace {
45/// Simplify affine.min / affine.max ops with the given constraints. They are
46/// either rewritten to affine.apply or left unchanged.
47template <typename OpTy>
48struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
49 using OpRewritePattern<OpTy>::OpRewritePattern;
50 SimplifyAffineMinMaxOp(MLIRContext *ctx,
51 const FlatAffineValueConstraints &constraints,
52 PatternBenefit benefit = 1)
53 : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
54
55 LogicalResult matchAndRewrite(OpTy op,
56 PatternRewriter &rewriter) const override {
57 FailureOr<AffineValueMap> simplified =
58 simplifyConstrainedMinMaxOp(op, constraints);
59 if (failed(simplified))
60 return failure();
61 rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
62 simplified->getOperands());
63 return success();
64 }
65
66 const FlatAffineValueConstraints &constraints;
67};
68} // namespace
69
71SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
72 TransformResults &results,
73 TransformState &state) {
74 // Get constraints for bounded values.
77 SmallVector<Value> boundedValues;
78 DenseSet<Operation *> boundedOps;
79 for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
80 getUpperBounds())) {
81 Value handle = std::get<0>(it);
82 for (Operation *op : state.getPayloadOps(handle)) {
83 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
84 auto diag =
86 << "expected bounded value handle to point to one or multiple "
87 "single-result index-typed ops";
88 diag.attachNote(op->getLoc()) << "multiple/non-index result";
89 return diag;
90 }
91 boundedValues.push_back(op->getResult(0));
92 boundedOps.insert(op);
93 lbs.push_back(std::get<1>(it));
94 ubs.push_back(std::get<2>(it));
95 }
96 }
97
98 // Build constraint set.
100 for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
101 unsigned pos;
102 if (!cstr.findVar(std::get<0>(it), &pos))
103 pos = cstr.appendSymbolVar(std::get<0>(it));
104 cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
105 // Note: addBound bounds are inclusive, but specified UB is exclusive.
106 cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
107 }
108
109 // Transform all targets.
111 for (Operation *target : state.getPayloadOps(getTarget())) {
112 if (!isa<AffineMinOp, AffineMaxOp>(target)) {
114 << "target must be affine.min or affine.max";
115 diag.attachNote(target->getLoc()) << "target op";
116 return diag;
117 }
118 if (boundedOps.contains(target)) {
120 << "target op result must not be constrained";
121 diag.attachNote(target->getLoc()) << "target/constrained op";
122 return diag;
123 }
124 targets.push_back(target);
125 }
127 // Canonicalization patterns are needed so that affine.apply ops are composed
128 // with the remaining affine.min/max ops.
129 AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
130 AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
131 patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
132 SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
133 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
134 // Apply the simplification pattern to a fixpoint.
136 targets, frozenPatterns,
138 .setListener(
139 static_cast<RewriterBase::Listener *>(rewriter.getListener()))
142 << "affine.min/max simplification did not converge";
143 return diag;
144 }
146}
147
148void SimplifyBoundedAffineOpsOp::getEffects(
150 consumesHandle(getTargetMutable(), effects);
151 for (OpOperand &operand : getBoundedValuesMutable())
152 onlyReadsHandle(operand, effects);
153 modifiesPayload(effects);
154}
155
156//===----------------------------------------------------------------------===//
157// SuperVectorizeOp
158//===----------------------------------------------------------------------===//
159
160LogicalResult SuperVectorizeOp::verify() {
161 if (getFastestVaryingPattern().has_value()) {
162 if (getFastestVaryingPattern()->size() != getVectorSizes().size())
163 return emitOpError()
164 << "fastest varying pattern specified with different size than "
165 "the vector size";
166 }
167 return success();
168}
169
171SuperVectorizeOp::apply(transform::TransformRewriter &rewriter,
172 TransformResults &results, TransformState &state) {
173 ArrayRef<int64_t> fastestVaryingPattern;
174 if (getFastestVaryingPattern().has_value())
175 fastestVaryingPattern = getFastestVaryingPattern().value();
176
177 for (Operation *target : state.getPayloadOps(getTarget()))
178 if (!target->getParentOfType<affine::AffineForOp>())
179 vectorizeChildAffineLoops(target, getVectorizeReductions(),
180 getVectorSizes(), fastestVaryingPattern);
181
183}
184
185void SuperVectorizeOp::getEffects(
187 consumesHandle(getTargetMutable(), effects);
188 modifiesPayload(effects);
189}
190
191//===----------------------------------------------------------------------===//
192// SimplifyMinMaxAffineOpsOp
193//===----------------------------------------------------------------------===//
195SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
196 TransformResults &results,
197 TransformState &state) {
199 for (Operation *target : state.getPayloadOps(getTarget())) {
200 if (!isa<AffineMinOp, AffineMaxOp>(target)) {
202 << "target must be affine.min or affine.max";
203 diag.attachNote(target->getLoc()) << "target op";
204 return diag;
205 }
206 targets.push_back(target);
207 }
208 bool modified = false;
209 if (failed(mlir::affine::simplifyAffineMinMaxOps(rewriter, targets,
210 &modified))) {
211 return emitDefiniteFailure()
212 << "affine.min/max simplification did not converge";
213 }
214 if (!modified) {
215 return emitSilenceableError()
216 << "the transform failed to simplify any of the target operations";
217 }
219}
220
221void SimplifyMinMaxAffineOpsOp::getEffects(
223 consumesHandle(getTargetMutable(), effects);
224 modifiesPayload(effects);
225}
226
227//===----------------------------------------------------------------------===//
228// Transform op registration
229//===----------------------------------------------------------------------===//
230
231namespace {
232class AffineTransformDialectExtension
234 AffineTransformDialectExtension> {
235public:
236 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
237
238 using Base::Base;
239
240 void init() {
241 declareGeneratedDialect<AffineDialect>();
242 declareGeneratedDialect<vector::VectorDialect>();
243
244 registerTransformOps<
245#define GET_OP_LIST
246#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
247 >();
248 }
249};
250} // namespace
251
252#define GET_OP_CLASSES
253#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
254
256 DialectRegistry &registry) {
257 registry.addExtensions<AffineTransformDialectExtension>();
258}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
b getContext())
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
bool findVar(Value val, unsigned *pos, unsigned offset=0) const
Looks up the position of the variable with the specified Value starting with variables at offset offs...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
FlatAffineValueConstraints is an extension of FlatLinearValueConstraints with helper functions for Af...
LogicalResult addBound(presburger::BoundType type, unsigned pos, AffineMap boundMap, ValueRange operands)
Adds a bound for the variable at the specified position with constraints being drawn from the specifi...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult simplifyAffineMinMaxOps(RewriterBase &rewriter, ArrayRef< Operation * > ops, bool *modified=nullptr)
This transform applies simplifyAffineMinOp and simplifyAffineMaxOp to all the affine....
FailureOr< AffineValueMap > simplifyConstrainedMinMaxOp(Operation *op, FlatAffineValueConstraints constraints)
Try to simplify the given affine.min or affine.max op to an affine map with a single result and opera...
Definition Utils.cpp:2286
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...