28 #include "llvm/Support/Debug.h"
32 #define DEBUG_TYPE "memref-transforms"
33 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
39 std::unique_ptr<TypeConverter>
40 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
45 options.useGenericFunctions = getUseGenericFunctions();
53 if (getDataLayout().has_value())
54 options.dataLayout = llvm::DataLayout(getDataLayout().value());
55 options.useBarePtrCallConv = getUseBarePtrCallConv();
60 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
61 return "LLVMTypeConverter";
71 explicit AllocToAllocaPattern(
Operation *analysisRoot, int64_t maxSize = 0)
73 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
78 rewriter, op, [
this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79 MemRefType type = alloc.getMemref().getType();
80 if (!type.hasStaticShape())
83 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
84 int64_t elementSize = dataLayout.
getTypeSize(type.getElementType());
85 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
95 void transform::ApplyAllocToAllocaOp::populatePatterns(
98 void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
100 patterns.
insert<AllocToAllocaPattern>(
101 state.getTopLevel(),
static_cast<int64_t
>(getSizeLimit().value_or(0)));
104 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
109 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
114 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
119 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
124 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
137 auto allocaOps = state.getPayloadOps(getAlloca());
143 for (
auto *op : allocaOps) {
144 auto alloca = cast<memref::AllocaOp>(op);
148 memref::GlobalOp globalOp;
152 assert(symbolTableOp &&
"expected alloca payload to be in symbol table");
156 Type resultType = alloca.getResult().getType();
159 globalOp = builder.create<memref::GlobalOp>(
162 symbolTable.insert(globalOp);
169 alloca, globalOp.getType(), globalOp.getName());
171 globalOps.push_back(globalOp);
172 getGlobalOps.push_back(getGlobalOp);
176 results.
set(cast<OpResult>(getGlobal()), globalOps);
177 results.
set(cast<OpResult>(getGetGlobal()), getGlobalOps);
182 void transform::MemRefAllocaToGlobalOp::getEffects(
199 for (
Operation *op : state.getPayloadOps(getTarget())) {
200 bool canApplyMultiBuffer =
true;
201 auto target = cast<memref::AllocOp>(op);
202 LLVM_DEBUG(
DBGS() <<
"Start multibuffer transform op: " << target <<
"\n";);
205 if (isa<memref::DeallocOp>(user))
207 auto loop = user->getParentOfType<LoopLikeOpInterface>();
209 LLVM_DEBUG(
DBGS() <<
"--allocation not used in a loop\n";
210 DBGS() <<
"----due to user: " << *user;);
211 canApplyMultiBuffer =
false;
215 if (!canApplyMultiBuffer) {
216 LLVM_DEBUG(
DBGS() <<
"--cannot apply multibuffering -> Skip\n";);
224 LLVM_DEBUG(
DBGS() <<
"--op failed to multibuffer\n";);
226 <<
"op failed to multibuffer";
229 results.push_back(*newBuffer);
231 transformResults.
set(cast<OpResult>(getResult()), results);
240 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
250 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
255 void transform::MemRefEraseDeadAllocAndStoresOp::build(
OpBuilder &builder,
272 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
276 <<
"could not find " << i
277 <<
"-th enclosing loop";
278 diag.attachNote(target->
getLoc()) <<
"target op";
281 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
286 if (
auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
290 <<
"unsupported target op";
291 diag.attachNote(target->
getLoc()) <<
"target op";
294 if (
failed(replacement)) {
296 emitSilenceableError() <<
"could not make target op loop-independent";
297 diag.attachNote(target->
getLoc()) <<
"target op";
300 results.
push_back(replacement->getDefiningOp());
309 class MemRefTransformDialectExtension
311 MemRefTransformDialectExtension> {
316 declareGeneratedDialect<affine::AffineDialect>();
317 declareGeneratedDialect<arith::ArithDialect>();
318 declareGeneratedDialect<memref::MemRefDialect>();
319 declareGeneratedDialect<nvgpu::NVGPUDialect>();
320 declareGeneratedDialect<vector::VectorDialect>();
322 registerTransformOps<
324 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
330 #define GET_OP_CLASSES
331 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp)
Implements transfer op write to read forwarding and dead transfer write optimizations.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)