28 #include "llvm/Support/Debug.h"
32 #define DEBUG_TYPE "memref-transforms"
33 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
39 std::unique_ptr<TypeConverter>
40 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
45 options.useGenericFunctions = getUseGenericFunctions();
53 if (getDataLayout().has_value())
54 options.dataLayout = llvm::DataLayout(getDataLayout().value());
55 options.useBarePtrCallConv = getUseBarePtrCallConv();
60 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
61 return "LLVMTypeConverter";
71 explicit AllocToAllocaPattern(
Operation *analysisRoot, int64_t maxSize = 0)
73 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
75 LogicalResult matchAndRewrite(memref::AllocOp op,
78 rewriter, op, [
this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79 MemRefType type = alloc.getMemref().
getType();
80 if (!type.hasStaticShape())
83 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
84 int64_t elementSize = dataLayout.
getTypeSize(type.getElementType());
85 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
95 void transform::ApplyAllocToAllocaOp::populatePatterns(
98 void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
100 patterns.
insert<AllocToAllocaPattern>(
101 state.getTopLevel(),
static_cast<int64_t
>(getSizeLimit().value_or(0)));
104 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
109 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
114 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
119 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
124 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
137 auto allocaOps = state.getPayloadOps(getAlloca());
143 for (
auto *op : allocaOps) {
144 auto alloca = cast<memref::AllocaOp>(op);
148 memref::GlobalOp globalOp;
152 assert(symbolTableOp &&
"expected alloca payload to be in symbol table");
156 Type resultType = alloca.getResult().getType();
159 globalOp = builder.create<memref::GlobalOp>(
162 symbolTable.insert(globalOp);
169 alloca, globalOp.getType(), globalOp.getName());
171 globalOps.push_back(globalOp);
172 getGlobalOps.push_back(getGlobalOp);
176 results.
set(cast<OpResult>(getGlobal()), globalOps);
177 results.
set(cast<OpResult>(getGetGlobal()), getGlobalOps);
182 void transform::MemRefAllocaToGlobalOp::getEffects(
198 for (
Operation *op : state.getPayloadOps(getTarget())) {
199 bool canApplyMultiBuffer =
true;
200 auto target = cast<memref::AllocOp>(op);
201 LLVM_DEBUG(
DBGS() <<
"Start multibuffer transform op: " << target <<
"\n";);
204 if (isa<memref::DeallocOp>(user))
206 auto loop = user->getParentOfType<LoopLikeOpInterface>();
208 LLVM_DEBUG(
DBGS() <<
"--allocation not used in a loop\n";
209 DBGS() <<
"----due to user: " << *user;);
210 canApplyMultiBuffer =
false;
214 if (!canApplyMultiBuffer) {
215 LLVM_DEBUG(
DBGS() <<
"--cannot apply multibuffering -> Skip\n";);
222 if (failed(newBuffer)) {
223 LLVM_DEBUG(
DBGS() <<
"--op failed to multibuffer\n";);
225 <<
"op failed to multibuffer";
228 results.push_back(*newBuffer);
230 transformResults.
set(cast<OpResult>(getResult()), results);
239 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
249 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
254 void transform::MemRefEraseDeadAllocAndStoresOp::build(
OpBuilder &builder,
271 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
275 <<
"could not find " << i
276 <<
"-th enclosing loop";
277 diag.attachNote(target->
getLoc()) <<
"target op";
280 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
284 FailureOr<Value> replacement = failure();
285 if (
auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
289 <<
"unsupported target op";
290 diag.attachNote(target->
getLoc()) <<
"target op";
293 if (failed(replacement)) {
295 emitSilenceableError() <<
"could not make target op loop-independent";
296 diag.attachNote(target->
getLoc()) <<
"target op";
299 results.
push_back(replacement->getDefiningOp());
308 class MemRefTransformDialectExtension
310 MemRefTransformDialectExtension> {
317 declareGeneratedDialect<affine::AffineDialect>();
318 declareGeneratedDialect<arith::ArithDialect>();
319 declareGeneratedDialect<memref::MemRefDialect>();
320 declareGeneratedDialect<nvgpu::NVGPUDialect>();
321 declareGeneratedDialect<vector::VectorDialect>();
323 registerTransformOps<
325 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
331 #define GET_OP_CLASSES
332 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
Track temporary allocations that are never read from.
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp)
Implements transfer op write to read forwarding and dead transfer write optimizations.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)