27 #include "llvm/Support/Debug.h"
// Debug-logging support for this file. DEBUG_TYPE is the channel name picked
// up by the LLVM_DEBUG(...) calls below. DBGS() returns llvm::dbgs() with a
// "[memref-transforms] " prefix already written, so every debug line is tagged
// with that channel name.
31 #define DEBUG_TYPE "memref-transforms"
32 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
38 std::unique_ptr<TypeConverter>
39 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
44 options.useGenericFunctions = getUseGenericFunctions();
52 if (getDataLayout().has_value())
53 options.dataLayout = llvm::DataLayout(getDataLayout().value());
54 options.useBarePtrCallConv = getUseBarePtrCallConv();
59 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
60 return "LLVMTypeConverter";
70 explicit AllocToAllocaPattern(
Operation *analysisRoot, int64_t maxSize = 0)
72 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
77 rewriter, op, [
this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
78 MemRefType type = alloc.getMemref().getType();
79 if (!type.hasStaticShape())
82 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
83 int64_t elementSize = dataLayout.
getTypeSize(type.getElementType());
84 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
94 void transform::ApplyAllocToAllocaOp::populatePatterns(
97 void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99 patterns.
insert<AllocToAllocaPattern>(
100 state.getTopLevel(),
static_cast<int64_t
>(getSizeLimit().value_or(0)));
103 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
108 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
113 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
118 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
123 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
136 auto allocaOps = state.getPayloadOps(getAlloca());
142 for (
auto *op : allocaOps) {
143 auto alloca = cast<memref::AllocaOp>(op);
147 memref::GlobalOp globalOp;
151 assert(symbolTableOp &&
"expected alloca payload to be in symbol table");
155 Type resultType = alloca.getResult().getType();
158 globalOp = builder.create<memref::GlobalOp>(
161 symbolTable.insert(globalOp);
168 alloca, globalOp.getType(), globalOp.getName());
170 globalOps.push_back(globalOp);
171 getGlobalOps.push_back(getGlobalOp);
175 results.
set(getGlobal().cast<
OpResult>(), globalOps);
176 results.
set(getGetGlobal().cast<
OpResult>(), getGlobalOps);
181 void transform::MemRefAllocaToGlobalOp::getEffects(
198 for (
Operation *op : state.getPayloadOps(getTarget())) {
199 bool canApplyMultiBuffer =
true;
200 auto target = cast<memref::AllocOp>(op);
201 LLVM_DEBUG(
DBGS() <<
"Start multibuffer transform op: " << target <<
"\n";);
204 if (isa<memref::DeallocOp>(user))
206 auto loop = user->getParentOfType<LoopLikeOpInterface>();
208 LLVM_DEBUG(
DBGS() <<
"--allocation not used in a loop\n";
209 DBGS() <<
"----due to user: " << *user;);
210 canApplyMultiBuffer =
false;
214 if (!canApplyMultiBuffer) {
215 LLVM_DEBUG(
DBGS() <<
"--cannot apply multibuffering -> Skip\n";);
223 LLVM_DEBUG(
DBGS() <<
"--op failed to multibuffer\n";);
225 <<
"op failed to multibuffer";
228 results.push_back(*newBuffer);
230 transformResults.
set(cast<OpResult>(getResult()), results);
239 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
249 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
254 void transform::MemRefEraseDeadAllocAndStoresOp::build(
OpBuilder &builder,
271 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
275 <<
"could not find " << i
276 <<
"-th enclosing loop";
277 diag.attachNote(target->
getLoc()) <<
"target op";
280 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
285 if (
auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
289 <<
"unsupported target op";
290 diag.attachNote(target->
getLoc()) <<
"target op";
293 if (
failed(replacement)) {
295 emitSilenceableError() <<
"could not make target op loop-independent";
296 diag.attachNote(target->
getLoc()) <<
"target op";
299 results.
push_back(replacement->getDefiningOp());
308 class MemRefTransformDialectExtension
310 MemRefTransformDialectExtension> {
315 declareGeneratedDialect<affine::AffineDialect>();
316 declareGeneratedDialect<arith::ArithDialect>();
317 declareGeneratedDialect<memref::MemRefDialect>();
318 declareGeneratedDialect<nvgpu::NVGPUDialect>();
319 declareGeneratedDialect<vector::VectorDialect>();
321 registerTransformOps<
323 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
329 #define GET_OP_CLASSES
330 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp)
Implements transfer op write to read forwarding and dead transfer write optimizations.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)