MLIR 22.0.0git
MemRefTransformOps.cpp
Go to the documentation of this file.
1//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
27#include "llvm/Support/Debug.h"
28
29using namespace mlir;
30
31#define DEBUG_TYPE "memref-transforms"
32#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33
34//===----------------------------------------------------------------------===//
35// Apply...ConversionPatternsOp
36//===----------------------------------------------------------------------===//
37
// Builds the LLVMTypeConverter described by this transform op. The lowering
// options are filled from the op's attributes (alloc lowering, generic
// functions, index bitwidth, optional data layout, bare-pointer calling
// convention), and the converter is created in this op's MLIRContext.
38std::unique_ptr<TypeConverter>
39transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
41 options.allocLowering =
44 options.useGenericFunctions = getUseGenericFunctions();
45
47 options.overrideIndexBitwidth(getIndexBitwidth());
48
49 // TODO: the following two options don't really make sense for
50 // memref_to_llvm_type_converter specifically but we should have a single
51 // to_llvm_type_converter.
// The data layout is only overridden when the op carries one; otherwise the
// option keeps its existing value.
52 if (getDataLayout().has_value())
53 options.dataLayout = llvm::DataLayout(getDataLayout().value());
54 options.useBarePtrCallConv = getUseBarePtrCallConv();
55
56 return std::make_unique<LLVMTypeConverter>(getContext(), options);
57}
58
59StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
60 return "LLVMTypeConverter";
61}
62
63//===----------------------------------------------------------------------===//
64// Apply...PatternsOp
65//===----------------------------------------------------------------------===//
66
67namespace {
// Rewrite pattern that turns a `memref.alloc` (and its matching dealloc, if
// any) into a stack allocation. Only statically shaped allocations are
// eligible, optionally bounded by a size limit computed with the data layout
// that is in scope at the allocation.
68class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
69public:
// `analysisRoot` seeds the DataLayoutAnalysis and provides the MLIRContext.
// `maxSize` is the size limit; 0 disables the limit (see the filter below).
70 explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
71 : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
72 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
73
// The lambda below is the eligibility filter, with signature
// `bool(memref::AllocOp, memref::DeallocOp)`. It rejects dynamic shapes and,
// when maxSize != 0, allocations whose total size is not below maxSize.
74 LogicalResult matchAndRewrite(memref::AllocOp op,
75 PatternRewriter &rewriter) const override {
77 rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
78 MemRefType type = alloc.getMemref().getType();
79 if (!type.hasStaticShape())
80 return false;
81
82 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
// NOTE(review): getTypeSize returns llvm::TypeSize, which is implicitly
// narrowed to int64_t here. Confirm that element types are never scalable.
83 int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
// NOTE(review): `getNumElements() * elementSize` can overflow int64_t for
// very large static shapes. Also, the strict `<` means an allocation of
// exactly maxSize (presumably bytes) is NOT converted. Confirm both against
// the size_limit contract.
84 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
85 }));
86 }
87
88private:
// Data layout queries rooted at the op passed to the constructor.
89 DataLayoutAnalysis dataLayoutAnalysis;
// Size limit used by the filter; 0 means "no limit".
90 int64_t maxSize;
91};
92} // namespace
93
// Stateless pattern-population hook for this op. AllocToAllocaPattern itself
// is registered by populatePatternsWithState, which needs the transform
// state's top-level op.
94void transform::ApplyAllocToAllocaOp::populatePatterns(
96
// Registers AllocToAllocaPattern rooted at the transform state's top-level op,
// which seeds the pattern's data layout analysis. A missing size_limit maps to
// 0, which the pattern treats as "no limit".
97void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99 patterns.insert<AllocToAllocaPattern>(
100 state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
101}
102
// Adds the memref op-expansion patterns to `patterns` (presumably via
// memref::populateExpandOpsPatterns; confirm).
103void transform::ApplyExpandOpsPatternsOp::populatePatterns(
106}
107
// Adds patterns that expand metadata-modifying memref ops (sizes, offset,
// strides) to `patterns` (presumably via
// memref::populateExpandStridedMetadataPatterns; confirm).
108void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
111}
112
// Adds patterns that extract address computations out of memory-accessing
// ops (presumably via memref::populateExtractAddressComputationsPatterns;
// confirm).
113void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
116}
117
// Adds patterns that fold memref aliasing ops into their consumer load/store
// ops (presumably via memref::populateFoldMemRefAliasOpPatterns; confirm).
118void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
121}
122
// Adds patterns that resolve `memref.dim` of values produced by ops with
// ranked shaped-type results (presumably via
// memref::populateResolveRankedShapedTypeResultDimsPatterns; confirm).
123void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
124 populatePatterns(RewritePatternSet &patterns) {
126}
127
128//===----------------------------------------------------------------------===//
129// AllocaToGlobalOp
130//===----------------------------------------------------------------------===//
131
133transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
// For every payload `memref.alloca`, this creates a private `memref.global`
// of the alloca's type in the nearest symbol table and replaces the alloca
// with a `memref.get_global` of it. The new globals and get_global ops are
// returned through the `global` and `get_global` result handles.
136 auto allocaOps = state.getPayloadOps(getAlloca());
137
140
141 // Transform `memref.alloca`s.
142 for (auto *op : allocaOps) {
143 auto alloca = cast<memref::AllocaOp>(op);
144 MLIRContext *ctx = rewriter.getContext();
145 Location loc = alloca->getLoc();
146
147 memref::GlobalOp globalOp;
148 {
149 // Find nearest symbol table.
151 assert(symbolTableOp && "expected alloca payload to be in symbol table");
152 SymbolTable symbolTable(symbolTableOp);
153
154 // Insert a `memref.global` into the symbol table.
155 Type resultType = alloca.getResult().getType();
// NOTE(review): the global is built with a standalone OpBuilder rather than
// `rewriter`, so rewriter listeners are not notified of its creation.
// Confirm this is intended.
156 OpBuilder builder(rewriter.getContext());
157 // TODO: Add a better builder for this.
158 globalOp = memref::GlobalOp::create(
159 builder, loc, StringAttr::get(ctx, "alloca"),
160 StringAttr::get(ctx, "private"), TypeAttr::get(resultType),
161 Attribute{}, UnitAttr{}, IntegerAttr{});
// Every global is first named "alloca"; SymbolTable::insert presumably
// renames it on collision. The final name is read back below through
// globalOp.getName().
162 symbolTable.insert(globalOp);
163 }
164
165 // Replace the `memref.alloca` with a `memref.get_global` accessing the
166 // global symbol inserted above.
167 rewriter.setInsertionPoint(alloca);
168 auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
169 alloca, globalOp.getType(), globalOp.getName());
170
171 globalOps.push_back(globalOp);
172 getGlobalOps.push_back(getGlobalOp);
173 }
174
175 // Assemble results.
176 results.set(cast<OpResult>(getGlobal()), globalOps);
177 results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
178
180}
181
// Declares the transform-level side effects of the op. It produces handles for
// all of its results, consumes the `alloca` handle (the payload allocas are
// replaced by apply()), and modifies the payload IR.
182void transform::MemRefAllocaToGlobalOp::getEffects(
184 producesHandle(getOperation()->getOpResults(), effects);
185 consumesHandle(getAllocaMutable(), effects);
186 modifiesPayload(effects);
187}
188
189//===----------------------------------------------------------------------===//
190// MemRefMultiBufferOp
191//===----------------------------------------------------------------------===//
192
// Multi-buffers each payload `memref.alloc` by getFactor() using
// memref::multiBuffer. Allocs that have a non-dealloc user outside any loop are
// skipped silently. The result handle maps to the newly created allocs.
193DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
195 transform::TransformResults &transformResults,
198 for (Operation *op : state.getPayloadOps(getTarget())) {
199 bool canApplyMultiBuffer = true;
200 auto target = cast<memref::AllocOp>(op);
201 LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
202 // Skip allocations not used in a loop.
// More precisely: an alloc is skipped when any user other than
// `memref.dealloc` is not nested in a LoopLikeOpInterface op. An alloc whose
// only users are deallocs passes this check.
203 for (Operation *user : target->getUsers()) {
204 if (isa<memref::DeallocOp>(user))
205 continue;
206 auto loop = user->getParentOfType<LoopLikeOpInterface>();
207 if (!loop) {
208 LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
209 DBGS() << "----due to user: " << *user;);
210 canApplyMultiBuffer = false;
211 break;
212 }
213 }
214 if (!canApplyMultiBuffer) {
215 LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
216 continue;
217 }
218
// getFactor() is the multiplier. getSkipAnalysis() feeds multiBuffer's
// skipOverrideAnalysis flag.
219 auto newBuffer =
220 memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
221
// NOTE(review): this returns immediately. Allocs already multi-buffered in
// earlier iterations stay rewritten, and the result handle is not set on
// this path.
222 if (failed(newBuffer)) {
223 LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
224 return emitSilenceableFailure(target->getLoc())
225 << "op failed to multibuffer";
226 }
227
228 results.push_back(*newBuffer);
229 }
230 transformResults.set(cast<OpResult>(getResult()), results);
232}
233
234//===----------------------------------------------------------------------===//
235// MemRefEraseDeadAllocAndStoresOp
236//===----------------------------------------------------------------------===//
237
239transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
// Cleans up memory traffic under `target`. First it forwards vector transfer
// stores to loads and removes dead stores.
243 // Apply store to load forwarding and dead store elimination.
244 vector::transferOpflowOpt(rewriter, target);
// Dead temporary allocations are presumably erased next via
// memref::eraseDeadAllocAndStores. Confirm.
247}
248
// The target handle is only read, so it stays usable after this op, even
// though applyToOne rewrites the payload it points to.
249void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
251 transform::onlyReadsHandle(getTargetMutable(), effects);
253}
// Custom builder. It records `target` as the op's only operand and adds no
// result types here.
254void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
256 Value target) {
257 result.addOperands(target);
258}
259
260//===----------------------------------------------------------------------===//
261// MemRefMakeLoopIndependentOp
262//===----------------------------------------------------------------------===//
263
// Rewrites `target` so that it no longer depends on the induction variables
// of its getNumLoops() innermost enclosing scf.for loops. Only
// `memref.alloca` targets are supported. The op defining the replacement value
// is pushed as this op's result.
264DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
268 // Gather IVs.
// Walk outward through enclosing scf.for ops, innermost first. The `i` shown
// in the error message is 0-based.
270 Operation *nextOp = target;
271 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
272 nextOp = nextOp->getParentOfType<scf::ForOp>();
273 if (!nextOp) {
274 DiagnosedSilenceableFailure diag = emitSilenceableError()
275 << "could not find " << i
276 << "-th enclosing loop";
277 diag.attachNote(target->getLoc()) << "target op";
278 return diag;
279 }
280 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
281 }
282
// A memref.alloca is rebuilt so that its dynamic sizes are independent of the
// gathered IVs. Any other target kind is rejected.
283 // Rewrite IR.
284 FailureOr<Value> replacement = failure();
285 if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
286 replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
287 } else {
288 DiagnosedSilenceableFailure diag = emitSilenceableError()
289 << "unsupported target op";
290 diag.attachNote(target->getLoc()) << "target op";
291 return diag;
292 }
293 if (failed(replacement)) {
295 emitSilenceableError() << "could not make target op loop-independent";
296 diag.attachNote(target->getLoc()) << "target op";
297 return diag;
298 }
299 results.push_back(replacement->getDefiningOp());
301}
302
303//===----------------------------------------------------------------------===//
304// Transform op registration
305//===----------------------------------------------------------------------===//
306
307namespace {
// Transform dialect extension that makes the MemRef transform ops available.
308class MemRefTransformDialectExtension
310 MemRefTransformDialectExtension> {
311public:
312 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)
313
314 using Base::Base;
315
// Declares the dialects whose ops these transforms may generate, then
// registers every op in the generated MemRef transform op list.
316 void init() {
317 declareGeneratedDialect<affine::AffineDialect>();
318 declareGeneratedDialect<arith::ArithDialect>();
319 declareGeneratedDialect<memref::MemRefDialect>();
320 declareGeneratedDialect<nvgpu::NVGPUDialect>();
321 declareGeneratedDialect<vector::VectorDialect>();
322
323 registerTransformOps<
324#define GET_OP_LIST
325#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
326 >();
327 }
328};
329} // namespace
330
331#define GET_OP_CLASSES
332#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
333
335 DialectRegistry &registry) {
// Adds the MemRef transform-dialect extension to `registry`.
336 registry.addExtensions<MemRefTransformDialectExtension>();
337}
return success()
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
b getContext())
... if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. (Part of the original sentence is missing from this extract.) The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
#define DBGS()
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
Track temporary allocations that are never read from.
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.