MLIR  19.0.0git
MemRefTransformOps.cpp
Go to the documentation of this file.
1 //===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
28 #include "llvm/Support/Debug.h"
29 
30 using namespace mlir;
31 
32 #define DEBUG_TYPE "memref-transforms"
33 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
34 
35 //===----------------------------------------------------------------------===//
36 // Apply...ConversionPatternsOp
37 //===----------------------------------------------------------------------===//
38 
39 std::unique_ptr<TypeConverter>
40 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
42  options.allocLowering =
45  options.useGenericFunctions = getUseGenericFunctions();
46 
48  options.overrideIndexBitwidth(getIndexBitwidth());
49 
50  // TODO: the following two options don't really make sense for
51  // memref_to_llvm_type_converter specifically but we should have a single
52  // to_llvm_type_converter.
53  if (getDataLayout().has_value())
54  options.dataLayout = llvm::DataLayout(getDataLayout().value());
55  options.useBarePtrCallConv = getUseBarePtrCallConv();
56 
57  return std::make_unique<LLVMTypeConverter>(getContext(), options);
58 }
59 
60 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
61  return "LLVMTypeConverter";
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // Apply...PatternsOp
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
70 public:
71  explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
72  : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
73  dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
74 
75  LogicalResult matchAndRewrite(memref::AllocOp op,
76  PatternRewriter &rewriter) const override {
78  rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79  MemRefType type = alloc.getMemref().getType();
80  if (!type.hasStaticShape())
81  return false;
82 
83  const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
84  int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
85  return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
86  }));
87  }
88 
89 private:
90  DataLayoutAnalysis dataLayoutAnalysis;
91  int64_t maxSize;
92 };
93 } // namespace
94 
95 void transform::ApplyAllocToAllocaOp::populatePatterns(
96  RewritePatternSet &patterns) {}
97 
98 void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99  RewritePatternSet &patterns, transform::TransformState &state) {
100  patterns.insert<AllocToAllocaPattern>(
101  state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
102 }
103 
104 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
105  RewritePatternSet &patterns) {
107 }
108 
109 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
110  RewritePatternSet &patterns) {
112 }
113 
114 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
115  RewritePatternSet &patterns) {
117 }
118 
119 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
120  RewritePatternSet &patterns) {
122 }
123 
124 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
125  populatePatterns(RewritePatternSet &patterns) {
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // AllocaToGlobalOp
131 //===----------------------------------------------------------------------===//
132 
134 transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
136  transform::TransformState &state) {
137  auto allocaOps = state.getPayloadOps(getAlloca());
138 
141 
142  // Transform `memref.alloca`s.
143  for (auto *op : allocaOps) {
144  auto alloca = cast<memref::AllocaOp>(op);
145  MLIRContext *ctx = rewriter.getContext();
146  Location loc = alloca->getLoc();
147 
148  memref::GlobalOp globalOp;
149  {
150  // Find nearest symbol table.
151  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
152  assert(symbolTableOp && "expected alloca payload to be in symbol table");
153  SymbolTable symbolTable(symbolTableOp);
154 
155  // Insert a `memref.global` into the symbol table.
156  Type resultType = alloca.getResult().getType();
157  OpBuilder builder(rewriter.getContext());
158  // TODO: Add a better builder for this.
159  globalOp = builder.create<memref::GlobalOp>(
160  loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
161  TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
162  symbolTable.insert(globalOp);
163  }
164 
165  // Replace the `memref.alloca` with a `memref.get_global` accessing the
166  // global symbol inserted above.
167  rewriter.setInsertionPoint(alloca);
168  auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
169  alloca, globalOp.getType(), globalOp.getName());
170 
171  globalOps.push_back(globalOp);
172  getGlobalOps.push_back(getGlobalOp);
173  }
174 
175  // Assemble results.
176  results.set(cast<OpResult>(getGlobal()), globalOps);
177  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
178 
180 }
181 
182 void transform::MemRefAllocaToGlobalOp::getEffects(
184  producesHandle(getGlobal(), effects);
185  producesHandle(getGetGlobal(), effects);
186  consumesHandle(getAlloca(), effects);
187  modifiesPayload(effects);
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // MemRefMultiBufferOp
192 //===----------------------------------------------------------------------===//
193 
194 DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
196  transform::TransformResults &transformResults,
197  transform::TransformState &state) {
198  SmallVector<Operation *> results;
199  for (Operation *op : state.getPayloadOps(getTarget())) {
200  bool canApplyMultiBuffer = true;
201  auto target = cast<memref::AllocOp>(op);
202  LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
203  // Skip allocations not used in a loop.
204  for (Operation *user : target->getUsers()) {
205  if (isa<memref::DeallocOp>(user))
206  continue;
207  auto loop = user->getParentOfType<LoopLikeOpInterface>();
208  if (!loop) {
209  LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
210  DBGS() << "----due to user: " << *user;);
211  canApplyMultiBuffer = false;
212  break;
213  }
214  }
215  if (!canApplyMultiBuffer) {
216  LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
217  continue;
218  }
219 
220  auto newBuffer =
221  memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
222 
223  if (failed(newBuffer)) {
224  LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
225  return emitSilenceableFailure(target->getLoc())
226  << "op failed to multibuffer";
227  }
228 
229  results.push_back(*newBuffer);
230  }
231  transformResults.set(cast<OpResult>(getResult()), results);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // MemRefEraseDeadAllocAndStoresOp
237 //===----------------------------------------------------------------------===//
238 
240 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
241  transform::TransformRewriter &rewriter, Operation *target,
243  transform::TransformState &state) {
244  // Apply store to load forwarding and dead store elimination.
245  vector::transferOpflowOpt(rewriter, target);
246  memref::eraseDeadAllocAndStores(rewriter, target);
248 }
249 
250 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
252  transform::onlyReadsHandle(getTarget(), effects);
254 }
255 void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
256  OperationState &result,
257  Value target) {
258  result.addOperands(target);
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // MemRefMakeLoopIndependentOp
263 //===----------------------------------------------------------------------===//
264 
265 DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
266  transform::TransformRewriter &rewriter, Operation *target,
268  transform::TransformState &state) {
269  // Gather IVs.
270  SmallVector<Value> ivs;
271  Operation *nextOp = target;
272  for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
273  nextOp = nextOp->getParentOfType<scf::ForOp>();
274  if (!nextOp) {
275  DiagnosedSilenceableFailure diag = emitSilenceableError()
276  << "could not find " << i
277  << "-th enclosing loop";
278  diag.attachNote(target->getLoc()) << "target op";
279  return diag;
280  }
281  ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
282  }
283 
284  // Rewrite IR.
285  FailureOr<Value> replacement = failure();
286  if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
287  replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
288  } else {
289  DiagnosedSilenceableFailure diag = emitSilenceableError()
290  << "unsupported target op";
291  diag.attachNote(target->getLoc()) << "target op";
292  return diag;
293  }
294  if (failed(replacement)) {
296  emitSilenceableError() << "could not make target op loop-independent";
297  diag.attachNote(target->getLoc()) << "target op";
298  return diag;
299  }
300  results.push_back(replacement->getDefiningOp());
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // Transform op registration
306 //===----------------------------------------------------------------------===//
307 
308 namespace {
309 class MemRefTransformDialectExtension
311  MemRefTransformDialectExtension> {
312 public:
313  using Base::Base;
314 
315  void init() {
316  declareGeneratedDialect<affine::AffineDialect>();
317  declareGeneratedDialect<arith::ArithDialect>();
318  declareGeneratedDialect<memref::MemRefDialect>();
319  declareGeneratedDialect<nvgpu::NVGPUDialect>();
320  declareGeneratedDialect<vector::VectorDialect>();
321 
322  registerTransformOps<
323 #define GET_OP_LIST
324 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
325  >();
326  }
327 };
328 } // namespace
329 
330 #define GET_OP_CLASSES
331 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
332 
334  DialectRegistry &registry) {
335  registry.addExtensions<MemRefTransformDialectExtension>();
336 }
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
#define DBGS()
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Definition: MultiBuffer.cpp:99
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:146
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp)
Implements transfer op write to read forwarding and dead transfer write optimizations.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)