27#include "llvm/Support/Debug.h"
// DEBUG_TYPE names the debug category ("memref-transforms") used by the
// LLVM_DEBUG(...) blocks in this file.
31#define DEBUG_TYPE "memref-transforms"
// DBGS() returns llvm::dbgs() with a "[memref-transforms] " prefix already
// streamed, so every debug line emitted through it is tagged with the
// category above.
32#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
38std::unique_ptr<TypeConverter>
39transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
44 options.useGenericFunctions = getUseGenericFunctions();
52 if (getDataLayout().has_value())
53 options.dataLayout = llvm::DataLayout(getDataLayout().value());
54 options.useBarePtrCallConv = getUseBarePtrCallConv();
59StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
60 return "LLVMTypeConverter";
70 explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
71 : OpRewritePattern<memref::AllocOp>(analysisRoot->
getContext()),
72 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
74 LogicalResult matchAndRewrite(memref::AllocOp op,
75 PatternRewriter &rewriter)
const override {
77 rewriter, op, [
this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
78 MemRefType type = alloc.getMemref().getType();
79 if (!type.hasStaticShape())
82 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
83 int64_t elementSize = dataLayout.
getTypeSize(type.getElementType());
84 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
89 DataLayoutAnalysis dataLayoutAnalysis;
94void transform::ApplyAllocToAllocaOp::populatePatterns(
97void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99 patterns.insert<AllocToAllocaPattern>(
103void transform::ApplyExpandOpsPatternsOp::populatePatterns(
108void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
113void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
118void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
123void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
142 for (
auto *op : allocaOps) {
143 auto alloca = cast<memref::AllocaOp>(op);
147 memref::GlobalOp globalOp;
151 assert(symbolTableOp &&
"expected alloca payload to be in symbol table");
155 Type resultType = alloca.getResult().getType();
158 globalOp = memref::GlobalOp::create(
159 builder, loc, StringAttr::get(ctx,
"alloca"),
160 StringAttr::get(ctx,
"private"), TypeAttr::get(resultType),
162 symbolTable.insert(globalOp);
169 alloca, globalOp.getType(), globalOp.getName());
171 globalOps.push_back(globalOp);
172 getGlobalOps.push_back(getGlobalOp);
176 results.
set(cast<OpResult>(getGlobal()), globalOps);
177 results.
set(cast<OpResult>(getGetGlobal()), getGlobalOps);
182void transform::MemRefAllocaToGlobalOp::getEffects(
199 bool canApplyMultiBuffer =
true;
200 auto target = cast<memref::AllocOp>(op);
201 LLVM_DEBUG(
DBGS() <<
"Start multibuffer transform op: " <<
target <<
"\n";);
204 if (isa<memref::DeallocOp>(user))
206 auto loop = user->getParentOfType<LoopLikeOpInterface>();
208 LLVM_DEBUG(
DBGS() <<
"--allocation not used in a loop\n";
209 DBGS() <<
"----due to user: " << *user;);
210 canApplyMultiBuffer =
false;
214 if (!canApplyMultiBuffer) {
215 LLVM_DEBUG(
DBGS() <<
"--cannot apply multibuffering -> Skip\n";);
223 LLVM_DEBUG(
DBGS() <<
"--op failed to multibuffer\n";);
225 <<
"op failed to multibuffer";
228 results.push_back(*newBuffer);
230 transformResults.
set(cast<OpResult>(getResult()), results);
239transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
244 vector::transferOpflowOpt(rewriter,
target);
249void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
254void transform::MemRefEraseDeadAllocAndStoresOp::build(
OpBuilder &builder,
271 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
275 <<
"could not find " << i
276 <<
"-th enclosing loop";
277 diag.attachNote(
target->getLoc()) <<
"target op";
280 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
285 if (
auto allocaOp = dyn_cast<memref::AllocaOp>(
target)) {
289 <<
"unsupported target op";
290 diag.attachNote(
target->getLoc()) <<
"target op";
295 emitSilenceableError() <<
"could not make target op loop-independent";
296 diag.attachNote(
target->getLoc()) <<
"target op";
308class MemRefTransformDialectExtension
310 MemRefTransformDialectExtension> {
317 declareGeneratedDialect<affine::AffineDialect>();
318 declareGeneratedDialect<arith::ArithDialect>();
319 declareGeneratedDialect<memref::MemRefDialect>();
320 declareGeneratedDialect<nvgpu::NVGPUDialect>();
321 declareGeneratedDialect<vector::VectorDialect>();
323 registerTransformOps<
325#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
331#define GET_OP_CLASSES
332#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
… if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted; the output argument `nBegin` is set to the replacement for `begin` (or to `begin` itself if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
Track temporary allocations that are never read from.
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.