MLIR  18.0.0git
MemRefTransformOps.cpp
Go to the documentation of this file.
1 //===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
27 #include "llvm/Support/Debug.h"
28 
29 using namespace mlir;
30 
31 #define DEBUG_TYPE "memref-transforms"
32 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33 
34 //===----------------------------------------------------------------------===//
35 // MemrefToLLVMTypeConverterOp
36 //===----------------------------------------------------------------------===//
37 
// Builds the LLVMTypeConverter described by this op's attributes. The visible
// lines set:
//   - generic-function usage,
//   - an index bitwidth override,
//   - an optional data layout,
//   - the bare-pointer calling convention.
// It returns the converter as a TypeConverter.
// NOTE(review): listing lines 40, 42-43 and 46 are missing from this
// rendering. They include the declaration of `options` and the right-hand
// side of `options.allocLowering`. The latter presumably selects between
// Malloc and AlignedAlloc from an op attribute; confirm against the original
// source.
38 std::unique_ptr<TypeConverter>
39 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
41  options.allocLowering =
44  options.useGenericFunctions = getUseGenericFunctions();
45 
47  options.overrideIndexBitwidth(getIndexBitwidth());
48 
49  // TODO: the following two options don't really make sense for
50  // memref_to_llvm_type_converter specifically but we should have a single
51  // to_llvm_type_converter.
// The data layout is overridden only when the attribute is present.
52  if (getDataLayout().has_value())
53  options.dataLayout = llvm::DataLayout(getDataLayout().value());
54  options.useBarePtrCallConv = getUseBarePtrCallConv();
55 
56  return std::make_unique<LLVMTypeConverter>(getContext(), options);
57 }
58 
// Returns the class name of the converter built by getTypeConverter(), which
// constructs an LLVMTypeConverter.
59 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
60  return "LLVMTypeConverter";
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Apply...PatternsOp
65 //===----------------------------------------------------------------------===//
66 
67 namespace {
// Rewrite pattern that replaces a memref.alloc with an alloca, guarded by a
// filter. An alloc qualifies only when:
//   - its memref type has a static shape, and
//   - maxSize == 0 (no limit), or its size in bytes is strictly below maxSize.
// The size in bytes is the element count times the data-layout size of the
// element type.
// The data layout is queried through a DataLayoutAnalysis rooted at
// `analysisRoot`, which is built once when the pattern is constructed.
// NOTE(review): listing line 76 is missing from this rendering. Given the
// closing "}));" and the allocToAlloca signature in the index, it presumably
// reads `return success(memref::allocToAlloca(`; confirm.
68 class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
69 public:
70  explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
71  : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
72  dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
73 
74  LogicalResult matchAndRewrite(memref::AllocOp op,
75  PatternRewriter &rewriter) const override {
// NOTE(review): the filter's `dealloc` parameter is unused.
77  rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
78  MemRefType type = alloc.getMemref().getType();
79  if (!type.hasStaticShape())
80  return false;
81 
82  const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
83  int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
// NOTE(review): `<` excludes an allocation of exactly maxSize bytes. Confirm
// this matches the documented meaning of `size_limit`. The product can also
// overflow int64_t for very large static shapes.
84  return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
85  }));
86  }
87 
88 private:
// Data layouts for the ops under the analysis root, queried per alloc.
89  DataLayoutAnalysis dataLayoutAnalysis;
// Size limit in bytes (strict upper bound); 0 disables the limit.
90  int64_t maxSize;
91 };
92 } // namespace
93 
// Adds no patterns. AllocToAllocaPattern is added in
// populatePatternsWithState, which has the transform state needed to root its
// data-layout analysis.
94 void transform::ApplyAllocToAllocaOp::populatePatterns(
95  RewritePatternSet &patterns) {}
96 
// Registers AllocToAllocaPattern, with its data-layout analysis rooted at the
// transform state's top-level op. An absent size limit becomes 0, which the
// pattern treats as "no limit".
// NOTE(review): if getSizeLimit() holds an unsigned value, the static_cast
// wraps values above INT64_MAX into negative limits; confirm the attribute
// type.
97 void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
98  RewritePatternSet &patterns, transform::TransformState &state) {
99  patterns.insert<AllocToAllocaPattern>(
100  state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
101 }
102 
// Populates `patterns` with the memref op-expansion patterns.
// NOTE(review): the body (listing line 105) is missing from this rendering.
// It presumably calls memref::populateExpandOpsPatterns(patterns); confirm.
103 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
104  RewritePatternSet &patterns) {
106 }
107 
// Populates `patterns` with the strided-metadata expansion patterns.
// NOTE(review): the body (listing line 110) is missing from this rendering.
// It presumably calls memref::populateExpandStridedMetadataPatterns(patterns);
// confirm.
108 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
109  RewritePatternSet &patterns) {
111 }
112 
// Populates `patterns` with the address-computation extraction patterns.
// NOTE(review): the body (listing line 115) is missing from this rendering.
// It presumably calls
// memref::populateExtractAddressComputationsPatterns(patterns); confirm.
113 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
114  RewritePatternSet &patterns) {
116 }
117 
// Populates `patterns` with the patterns that fold memref aliasing ops into
// their load/store consumers.
// NOTE(review): the body (listing line 120) is missing from this rendering.
// It presumably calls memref::populateFoldMemRefAliasOpPatterns(patterns);
// confirm.
118 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
119  RewritePatternSet &patterns) {
121 }
122 
// Populates `patterns` with the patterns that resolve memref.dim of ranked
// shaped-type results.
// NOTE(review): the body (listing line 125) is missing from this rendering.
// It presumably calls
// memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
// confirm.
123 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
124  populatePatterns(RewritePatternSet &patterns) {
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // MemRefAllocaToGlobalOp
130 //===----------------------------------------------------------------------===//
131 
// For each payload memref.alloca, this:
//   1. creates a private memref.global of the alloca's type, named "alloca",
//      and inserts it into the nearest enclosing symbol table;
//   2. replaces the alloca with a memref.get_global of that symbol.
// The new globals and get_global ops are mapped to the op's two results.
// NOTE(review): listing lines 132, 134, 138-139 and 178 are missing from this
// rendering. They presumably hold the return type, the `results` parameter,
// the declarations of `globalOps` / `getGlobalOps` and the final return.
// NOTE(review): the global is created with a standalone OpBuilder rather than
// `rewriter`, so a listener attached to the TransformRewriter is not notified
// of its creation. The local `ctx` also duplicates rewriter.getContext().
133 transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
135  transform::TransformState &state) {
136  auto allocaOps = state.getPayloadOps(getAlloca());
137 
140 
141  // Transform `memref.alloca`s.
142  for (auto *op : allocaOps) {
143  auto alloca = cast<memref::AllocaOp>(op);
144  MLIRContext *ctx = rewriter.getContext();
145  Location loc = alloca->getLoc();
146 
147  memref::GlobalOp globalOp;
148  {
149  // Find nearest symbol table.
150  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
// The check below is an assert only; release builds do not guard against a
// payload alloca that has no enclosing symbol table.
151  assert(symbolTableOp && "expected alloca payload to be in symbol table");
152  SymbolTable symbolTable(symbolTableOp);
153 
154  // Insert a `memref.global` into the symbol table.
155  Type resultType = alloca.getResult().getType();
156  OpBuilder builder(rewriter.getContext());
157  // TODO: Add a better builder for this.
158  globalOp = builder.create<memref::GlobalOp>(
159  loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
160  TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
// Every global is requested under the name "alloca". SymbolTable::insert is
// expected to make the name unique on collision; confirm.
161  symbolTable.insert(globalOp);
162  }
163 
164  // Replace the `memref.alloca` with a `memref.get_global` accessing the
165  // global symbol inserted above.
166  rewriter.setInsertionPoint(alloca);
167  auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
168  alloca, globalOp.getType(), globalOp.getName());
169 
170  globalOps.push_back(globalOp);
171  getGlobalOps.push_back(getGlobalOp);
172  }
173 
174  // Assemble results.
// NOTE(review): the member form `.cast<OpResult>()` differs from the free
// function `cast<OpResult>(...)` used in MemRefMultiBufferOp::apply. The
// member form is deprecated in recent MLIR; confirm for this version.
175  results.set(getGlobal().cast<OpResult>(), globalOps);
176  results.set(getGetGlobal().cast<OpResult>(), getGlobalOps);
177 
179 }
180 
// Declares the op's effects:
//   - it produces the `global` and `get_global` result handles;
//   - it consumes the `alloca` handle (apply replaces those allocas);
//   - it modifies the payload IR.
// Listing line 182 (the `effects` parameter) is missing from this rendering.
181 void transform::MemRefAllocaToGlobalOp::getEffects(
183  producesHandle(getGlobal(), effects);
184  producesHandle(getGetGlobal(), effects);
185  consumesHandle(getAlloca(), effects);
186  modifiesPayload(effects);
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // MemRefMultiBufferOp
191 //===----------------------------------------------------------------------===//
192 
// Multi-buffers each payload memref.alloc with memref::multiBuffer, using the
// op's factor and skip-analysis flag, and maps the new allocs to the op's
// result.
//   - An alloc is skipped (no result entry) when any of its users other than
//     a memref.dealloc has no enclosing LoopLikeOpInterface op.
//   - If memref::multiBuffer fails, the op returns a silenceable failure
//     located at that alloc.
// NOTE(review): because skipped allocs produce no result, the result handle
// may be shorter than the target handle. A failure on a later alloc also
// leaves earlier allocs already rewritten. Listing lines 194 (presumably the
// `rewriter` parameter) and 231 (presumably the final return) are missing
// from this rendering.
193 DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
195  transform::TransformResults &transformResults,
196  transform::TransformState &state) {
197  SmallVector<Operation *> results;
198  for (Operation *op : state.getPayloadOps(getTarget())) {
199  bool canApplyMultiBuffer = true;
200  auto target = cast<memref::AllocOp>(op);
201  LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
202  // Skip allocations not used in a loop.
203  for (Operation *user : target->getUsers()) {
204  if (isa<memref::DeallocOp>(user))
205  continue;
206  auto loop = user->getParentOfType<LoopLikeOpInterface>();
207  if (!loop) {
// NOTE(review): the second debug line below does not end with "\n".
208  LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
209  DBGS() << "----due to user: " << *user;);
210  canApplyMultiBuffer = false;
211  break;
212  }
213  }
214  if (!canApplyMultiBuffer) {
215  LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
216  continue;
217  }
218 
219  auto newBuffer =
220  memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
221 
222  if (failed(newBuffer)) {
223  LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
224  return emitSilenceableFailure(target->getLoc())
225  << "op failed to multibuffer";
226  }
227 
228  results.push_back(*newBuffer);
229  }
230  transformResults.set(cast<OpResult>(getResult()), results);
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // MemRefEraseDeadAllocAndStoresOp
236 //===----------------------------------------------------------------------===//
237 
// Runs two cleanups scoped to `target`, in order:
//   1. vector transfer write-to-read forwarding and dead transfer-write
//      elimination;
//   2. memref dead alloc-and-store erasure.
// Listing lines 238, 241 and 246 are missing from this rendering. They
// presumably hold the return type, the `results` parameter and the return
// statement.
239 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
240  transform::TransformRewriter &rewriter, Operation *target,
242  transform::TransformState &state) {
243  // Apply store to load forwarding and dead store elimination.
244  vector::transferOpflowOpt(rewriter, target);
245  memref::eraseDeadAllocAndStores(rewriter, target);
247 }
248 
// Declares that the target handle is only read, not consumed.
// NOTE(review): listing lines 250 (the `effects` parameter) and 252 are
// missing from this rendering. Line 252 presumably declares the payload
// modification; confirm.
249 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
251  transform::onlyReadsHandle(getTarget(), effects);
253 }
// Custom builder: adds `target` as the only operand. It adds no result types
// or attributes, and `builder` is unused.
254 void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
255  OperationState &result,
256  Value target) {
257  result.addOperands(target);
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // MemRefMakeLoopIndependentOp
262 //===----------------------------------------------------------------------===//
263 
// Makes a memref.alloca independent of its enclosing scf.for loops:
//   1. Walks outward from `target` and collects the induction variables of
//      getNumLoops() enclosing scf.for ops, innermost first.
//   2. Calls memref::replaceWithIndependentOp with those IVs. Only
//      memref.alloca targets are supported.
//   3. Pushes the defining op of the replacement value into `results`.
// Every failure path returns a silenceable error with a note at the target.
// NOTE(review): listing lines 266 (presumably the `results` parameter), 294
// (presumably the start of the `diag` declaration) and 300 (presumably the
// final return) are missing from this rendering.
264 DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
265  transform::TransformRewriter &rewriter, Operation *target,
267  transform::TransformState &state) {
268  // Gather IVs.
269  SmallVector<Value> ivs;
270  Operation *nextOp = target;
271  for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
272  nextOp = nextOp->getParentOfType<scf::ForOp>();
// NOTE(review): `i` is zero-based, so a target with no enclosing loop
// reports "could not find 0-th enclosing loop".
273  if (!nextOp) {
274  DiagnosedSilenceableFailure diag = emitSilenceableError()
275  << "could not find " << i
276  << "-th enclosing loop";
277  diag.attachNote(target->getLoc()) << "target op";
278  return diag;
279  }
280  ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
281  }
282 
283  // Rewrite IR.
284  FailureOr<Value> replacement = failure();
285  if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
286  replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
287  } else {
288  DiagnosedSilenceableFailure diag = emitSilenceableError()
289  << "unsupported target op";
290  diag.attachNote(target->getLoc()) << "target op";
291  return diag;
292  }
293  if (failed(replacement)) {
295  emitSilenceableError() << "could not make target op loop-independent";
296  diag.attachNote(target->getLoc()) << "target op";
297  return diag;
298  }
// NOTE(review): getDefiningOp() is null if the replacement is a block
// argument. The code assumes replaceWithIndependentOp always returns an op
// result; confirm.
299  results.push_back(replacement->getDefiningOp());
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Transform op registration
305 //===----------------------------------------------------------------------===//
306 
307 namespace {
// Transform dialect extension for the MemRef transform ops. It:
//   - declares the dialects these ops may generate: affine, arith, memref,
//     nvgpu and vector;
//   - registers every op in the generated op list.
// Listing line 309 (the base-class clause) is missing from this rendering.
308 class MemRefTransformDialectExtension
310  MemRefTransformDialectExtension> {
311 public:
312  using Base::Base;
313 
314  void init() {
315  declareGeneratedDialect<affine::AffineDialect>();
316  declareGeneratedDialect<arith::ArithDialect>();
317  declareGeneratedDialect<memref::MemRefDialect>();
318  declareGeneratedDialect<nvgpu::NVGPUDialect>();
319  declareGeneratedDialect<vector::VectorDialect>();
320 
321  registerTransformOps<
322 #define GET_OP_LIST
323 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
324  >();
325  }
326 };
327 } // namespace
328 
329 #define GET_OP_CLASSES
330 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
331 
// Tail of registerTransformDialectExtension(DialectRegistry &registry). The
// line with the function name (listing line 332) is missing from this
// rendering. The function adds MemRefTransformDialectExtension to `registry`.
333  DialectRegistry &registry) {
334  registry.addExtensions<MemRefTransformDialectExtension>();
335 }
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
#define DBGS()
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Definition: MultiBuffer.cpp:99
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:146
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp)
Implements transfer op write to read forwarding and dead transfer write optimizations.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)