MLIR  19.0.0git
MemRefTransformOps.cpp
Go to the documentation of this file.
1 //===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
28 #include "llvm/Support/Debug.h"
29 
30 using namespace mlir;
31 
32 #define DEBUG_TYPE "memref-transforms"
33 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
34 
35 //===----------------------------------------------------------------------===//
36 // Apply...ConversionPatternsOp
37 //===----------------------------------------------------------------------===//
38 
/// Builds the type converter for this op: fills an `options` object from the
/// op's accessors (generic-functions flag, index bitwidth, optional data
/// layout, bare-pointer calling convention) and returns a heap-allocated
/// LLVMTypeConverter constructed from the op's context and those options.
///
/// NOTE(review): listing lines 41, 43-44 and 47 are absent from this extract.
/// The declaration of `options`, the right-hand side of the `allocLowering`
/// assignment, and whatever precedes the `overrideIndexBitwidth` call (possibly
/// a guard) are therefore not visible here -- confirm against the upstream file.
39 std::unique_ptr<TypeConverter>
40 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
42  options.allocLowering =
45  options.useGenericFunctions = getUseGenericFunctions();
46 
48  options.overrideIndexBitwidth(getIndexBitwidth());
49 
50  // TODO: the following two options don't really make sense for
51  // memref_to_llvm_type_converter specifically but we should have a single
52  // to_llvm_type_converter.
    // The data layout is only overridden when the op carries one.
53  if (getDataLayout().has_value())
54  options.dataLayout = llvm::DataLayout(getDataLayout().value());
55  options.useBarePtrCallConv = getUseBarePtrCallConv();
56 
57  return std::make_unique<LLVMTypeConverter>(getContext(), options);
58 }
59 
/// Returns the literal string "LLVMTypeConverter", i.e. the name of the class
/// that getTypeConverter() instantiates.
60 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
61  return "LLVMTypeConverter";
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // Apply...PatternsOp
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
/// Rewrite pattern rooted on memref::AllocOp. It owns a DataLayoutAnalysis
/// built from `analysisRoot` and an optional byte-size limit `maxSize`
/// (0 means "no limit", see the filter below).
///
/// NOTE(review): listing line 77 is absent from this extract, so the call that
/// receives `rewriter`, `op` and the filter lambda is not visible. The
/// cross-reference index of this page lists memref::allocToAlloca with a
/// matching `filter` parameter, so that is presumably the callee -- confirm
/// against the upstream file.
69 class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
70 public:
    /// \param analysisRoot op supplying the MLIRContext and the root of the
    ///        data layout analysis.
    /// \param maxSize size threshold compared against the allocation's byte
    ///        size; 0 disables the check.
71  explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
72  : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
73  dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
74 
75  LogicalResult matchAndRewrite(memref::AllocOp op,
76  PatternRewriter &rewriter) const override {
    // Filter: reject allocations without a static shape; otherwise accept when
    // no limit is set or when (element count * element size) is strictly less
    // than `maxSize`. The `dealloc` argument is not used by the filter.
78  rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79  MemRefType type = alloc.getMemref().getType();
80  if (!type.hasStaticShape())
81  return false;
82 
    // Element size is queried from the data layout that applies at `alloc`.
83  const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
84  int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
    // NOTE(review): the product is not checked for overflow -- presumably
    // fine for realistic shapes, confirm if very large memrefs are possible.
85  return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
86  }));
87  }
88 
89 private:
    // Data layout lookup rooted at the op passed to the constructor.
90  DataLayoutAnalysis dataLayoutAnalysis;
    // Size threshold; 0 means unlimited.
91  int64_t maxSize;
92 };
93 } // namespace
94 
/// Adds no patterns. AllocToAllocaPattern is inserted by
/// populatePatternsWithState, which has the transform state needed to obtain
/// the top-level op used as the pattern's analysis root.
95 void transform::ApplyAllocToAllocaOp::populatePatterns(
96  RewritePatternSet &patterns) {}
97 
/// Inserts AllocToAllocaPattern into `patterns`, constructed with the
/// transform state's top-level op as analysis root and this op's size limit.
/// A missing size limit becomes 0, which the pattern treats as "no limit".
98 void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99  RewritePatternSet &patterns, transform::TransformState &state) {
100  patterns.insert<AllocToAllocaPattern>(
101  state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
102 }
103 
/// Pattern-population hook for ApplyExpandOpsPatternsOp.
/// NOTE(review): listing line 106 (the body) is absent from this extract;
/// presumably it forwards `patterns` to memref::populateExpandOpsPatterns,
/// which appears in this page's cross-reference index -- confirm upstream.
104 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
105  RewritePatternSet &patterns) {
107 }
108 
/// Pattern-population hook for ApplyExpandStridedMetadataPatternsOp.
/// NOTE(review): listing line 111 (the body) is absent from this extract;
/// presumably it forwards `patterns` to
/// memref::populateExpandStridedMetadataPatterns, which appears in this page's
/// cross-reference index -- confirm upstream.
109 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
110  RewritePatternSet &patterns) {
112 }
113 
/// Pattern-population hook for ApplyExtractAddressComputationsPatternsOp.
/// NOTE(review): listing line 116 (the body) is absent from this extract;
/// presumably it forwards `patterns` to
/// memref::populateExtractAddressComputationsPatterns, which appears in this
/// page's cross-reference index -- confirm upstream.
114 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
115  RewritePatternSet &patterns) {
117 }
118 
/// Pattern-population hook for ApplyFoldMemrefAliasOpsPatternsOp.
/// NOTE(review): listing line 121 (the body) is absent from this extract;
/// presumably it forwards `patterns` to
/// memref::populateFoldMemRefAliasOpPatterns, which appears in this page's
/// cross-reference index -- confirm upstream.
119 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
120  RewritePatternSet &patterns) {
122 }
123 
/// Pattern-population hook for
/// ApplyResolveRankedShapedTypeResultDimsPatternsOp.
/// NOTE(review): listing line 126 (the body) is absent from this extract;
/// presumably it forwards `patterns` to
/// memref::populateResolveRankedShapedTypeResultDimsPatterns, which appears in
/// this page's cross-reference index -- confirm upstream.
124 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
125  populatePatterns(RewritePatternSet &patterns) {
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // AllocaToGlobalOp
131 //===----------------------------------------------------------------------===//
132 
/// For every payload op associated with the `alloca` operand handle: creates a
/// memref.global (symbol name "alloca", visibility "private", typed like the
/// alloca's result), inserts it into the nearest enclosing symbol table, and
/// replaces the alloca with a memref.get_global referring to that global. The
/// created globals and get_globals are then bound to the op's two results.
///
/// NOTE(review): listing lines 133, 135, 139-140 and 179 are absent from this
/// extract. The return type, the parameter declaring `results`, the
/// declarations of `globalOps` / `getGlobalOps`, and the final return statement
/// are therefore not visible here -- confirm against the upstream file.
134 transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
136  transform::TransformState &state) {
137  auto allocaOps = state.getPayloadOps(getAlloca());
138 
141 
142  // Transform `memref.alloca`s.
143  for (auto *op : allocaOps) {
    // `cast` asserts on mismatch: every payload op is expected to be a
    // memref.alloca.
144  auto alloca = cast<memref::AllocaOp>(op);
145  MLIRContext *ctx = rewriter.getContext();
146  Location loc = alloca->getLoc();
147 
148  memref::GlobalOp globalOp;
149  {
150  // Find nearest symbol table.
151  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
152  assert(symbolTableOp && "expected alloca payload to be in symbol table");
153  SymbolTable symbolTable(symbolTableOp);
154 
155  // Insert a `memref.global` into the symbol table.
156  Type resultType = alloca.getResult().getType();
    // The builder is created from the context only, with no insertion point
    // set in the visible code; the op is placed by `symbolTable.insert` below.
    // NOTE(review): every global is created with the same name "alloca" --
    // presumably the symbol table renames on conflict; confirm.
157  OpBuilder builder(rewriter.getContext());
158  // TODO: Add a better builder for this.
159  globalOp = builder.create<memref::GlobalOp>(
160  loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
161  TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
162  symbolTable.insert(globalOp);
163  }
164 
165  // Replace the `memref.alloca` with a `memref.get_global` accessing the
166  // global symbol inserted above.
167  rewriter.setInsertionPoint(alloca);
168  auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
169  alloca, globalOp.getType(), globalOp.getName());
170 
171  globalOps.push_back(globalOp);
172  getGlobalOps.push_back(getGlobalOp);
173  }
174 
175  // Assemble results.
176  results.set(cast<OpResult>(getGlobal()), globalOps);
177  results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
178 
180 }
181 
/// Declares this op's effects: all op results are produced handles, the
/// `alloca` operand handle is consumed, and the payload IR is modified.
/// NOTE(review): listing line 183 (the parameter list, presumably declaring
/// `effects`) is absent from this extract -- confirm against the upstream file.
182 void transform::MemRefAllocaToGlobalOp::getEffects(
184  producesHandle(getOperation()->getOpResults(), effects);
185  consumesHandle(getAllocaMutable(), effects);
186  modifiesPayload(effects);
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // MemRefMultiBufferOp
191 //===----------------------------------------------------------------------===//
192 
/// Applies memref::multiBuffer to every payload memref.alloc of the target
/// handle whose non-dealloc users all sit inside a loop-like op. Allocations
/// that do not qualify are skipped silently (debug output only). A failing
/// multiBuffer call aborts with a silenceable failure. The new buffers are
/// bound to the op's result.
///
/// NOTE(review): listing lines 194 and 231 are absent from this extract. The
/// parameter declaring `rewriter` and the final return statement are therefore
/// not visible here -- confirm against the upstream file.
193 DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
195  transform::TransformResults &transformResults,
196  transform::TransformState &state) {
197  SmallVector<Operation *> results;
198  for (Operation *op : state.getPayloadOps(getTarget())) {
199  bool canApplyMultiBuffer = true;
    // `cast` asserts on mismatch: payload ops are expected to be memref.alloc.
200  auto target = cast<memref::AllocOp>(op);
201  LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
202  // Skip allocations not used in a loop.
203  for (Operation *user : target->getUsers()) {
    // Deallocations are allowed to live outside of any loop.
204  if (isa<memref::DeallocOp>(user))
205  continue;
206  auto loop = user->getParentOfType<LoopLikeOpInterface>();
207  if (!loop) {
208  LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
209  DBGS() << "----due to user: " << *user;);
210  canApplyMultiBuffer = false;
211  break;
212  }
213  }
214  if (!canApplyMultiBuffer) {
215  LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
216  continue;
217  }
218 
219  auto newBuffer =
220  memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
221 
    // A single failing allocation aborts the whole transform; results gathered
    // so far are not set.
222  if (failed(newBuffer)) {
223  LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
224  return emitSilenceableFailure(target->getLoc())
225  << "op failed to multibuffer";
226  }
227 
228  results.push_back(*newBuffer);
229  }
230  transformResults.set(cast<OpResult>(getResult()), results);
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // MemRefEraseDeadAllocAndStoresOp
236 //===----------------------------------------------------------------------===//
237 
/// Runs vector::transferOpflowOpt and then memref::eraseDeadAllocAndStores on
/// `target`, both through the transform rewriter.
/// NOTE(review): listing lines 238, 241 and 246 are absent from this extract.
/// The return type, one parameter, and the return statement are therefore not
/// visible here -- confirm against the upstream file.
239 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
240  transform::TransformRewriter &rewriter, Operation *target,
242  transform::TransformState &state) {
243  // Apply store to load forwarding and dead store elimination.
244  vector::transferOpflowOpt(rewriter, target);
245  memref::eraseDeadAllocAndStores(rewriter, target);
247 }
248 
/// Declares that the target operand handle is only read.
/// NOTE(review): listing lines 250 and 252 are absent from this extract. The
/// parameter list (presumably declaring `effects`) and a second body statement
/// are therefore not visible here -- confirm against the upstream file.
249 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
251  transform::onlyReadsHandle(getTargetMutable(), effects);
253 }
/// Builder taking only the target handle: adds `target` as the sole operand of
/// `result`. No attributes or result types are added here, and `builder` is
/// unused.
254 void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
255  OperationState &result,
256  Value target) {
257  result.addOperands(target);
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // MemRefMakeLoopIndependentOp
262 //===----------------------------------------------------------------------===//
263 
/// Makes `target` independent of the induction variables of its
/// `getNumLoops()` innermost enclosing scf.for loops. Only memref.alloca
/// targets are supported; the rewrite is delegated to
/// memref::replaceWithIndependentOp and the defining op of the replacement
/// value is appended to `results`.
///
/// Emits a silenceable error (with a note on the target location) when fewer
/// enclosing scf.for loops exist than requested, when the target is not a
/// memref.alloca, or when the rewrite fails.
///
/// NOTE(review): listing lines 266, 294 and 300 are absent from this extract.
/// The parameter declaring `results`, the declaration of `diag` in the last
/// failure branch, and the final return statement are therefore not visible
/// here -- confirm against the upstream file.
264 DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
265  transform::TransformRewriter &rewriter, Operation *target,
267  transform::TransformState &state) {
268  // Gather IVs.
269  SmallVector<Value> ivs;
270  Operation *nextOp = target;
    // Walk outwards one scf.for at a time, starting from the target itself.
271  for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
272  nextOp = nextOp->getParentOfType<scf::ForOp>();
273  if (!nextOp) {
    // `i` is the zero-based index of the loop that could not be found.
274  DiagnosedSilenceableFailure diag = emitSilenceableError()
275  << "could not find " << i
276  << "-th enclosing loop";
277  diag.attachNote(target->getLoc()) << "target op";
278  return diag;
279  }
280  ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
281  }
282 
283  // Rewrite IR.
284  FailureOr<Value> replacement = failure();
285  if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
286  replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
287  } else {
288  DiagnosedSilenceableFailure diag = emitSilenceableError()
289  << "unsupported target op";
290  diag.attachNote(target->getLoc()) << "target op";
291  return diag;
292  }
293  if (failed(replacement)) {
295  emitSilenceableError() << "could not make target op loop-independent";
296  diag.attachNote(target->getLoc()) << "target op";
297  return diag;
298  }
299  results.push_back(replacement->getDefiningOp());
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Transform op registration
305 //===----------------------------------------------------------------------===//
306 
307 namespace {
/// Transform dialect extension that registers the MemRef transform ops.
/// NOTE(review): listing line 309 (the start of the base-class specifier) is
/// absent from this extract; only its trailing template argument is visible --
/// confirm against the upstream file.
308 class MemRefTransformDialectExtension
310  MemRefTransformDialectExtension> {
311 public:
    // Inherit the base class constructors.
312  using Base::Base;
313 
    /// Declares the affine, arith, memref, nvgpu and vector dialects as
    /// generated dialects, then registers every op listed in the
    /// TableGen-generated op list.
314  void init() {
315  declareGeneratedDialect<affine::AffineDialect>();
316  declareGeneratedDialect<arith::ArithDialect>();
317  declareGeneratedDialect<memref::MemRefDialect>();
318  declareGeneratedDialect<nvgpu::NVGPUDialect>();
319  declareGeneratedDialect<vector::VectorDialect>();
320 
    // With GET_OP_LIST defined, the include expands to the comma-separated
    // list of op classes used as template arguments.
321  registerTransformOps<
322 #define GET_OP_LIST
323 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
324  >();
325  }
326 };
327 } // namespace
328 
329 #define GET_OP_CLASSES
330 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
331 
/// Adds MemRefTransformDialectExtension to `registry`.
/// NOTE(review): listing line 332 (return type and function name) is absent
/// from this extract; this page's cross-reference index lists
/// `void registerTransformDialectExtension(DialectRegistry &registry)`, which
/// is presumably this function -- confirm against the upstream file.
333  DialectRegistry &registry) {
334  registry.addExtensions<MemRefTransformDialectExtension>();
335 }
static uint64_t getIndexBitwidth(DataLayoutEntryListRef params)
Returns the bitwidth of the index type if specified in the param list.
static MLIRContext * getContext(OpFoldResult val)
#define DBGS()
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
llvm::TypeSize getTypeSize(Type t) const
Returns the size of the given type in the current scope.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
@ Malloc
Use malloc for heap allocations.
@ AlignedAlloc
Use aligned_alloc for heap allocations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void populateResolveRankedShapedTypeResultDimsPatterns(RewritePatternSet &patterns)
Appends patterns that resolve memref.dim operations with values that are defined by operations that i...
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp)
FailureOr< memref::AllocOp > multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis=false)
Transformation to do multi-buffering/array expansion to remove dependencies on the temporary allocati...
Definition: MultiBuffer.cpp:99
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:146
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns)
Appends patterns for extracting address computations from the instructions with memory accesses such ...
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns)
Appends patterns for expanding memref operations that modify the metadata (sizes, offset,...
void registerTransformDialectExtension(DialectRegistry &registry)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void transferOpflowOpt(RewriterBase &rewriter, Operation *rootOp)
Implements transfer op write to read forwarding and dead transfer write optimizations.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)