MLIR  19.0.0git
BufferizationTransformOps.cpp
Go to the documentation of this file.
1 //===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 
22 using namespace mlir;
23 using namespace mlir::bufferization;
24 using namespace mlir::transform;
25 
26 //===----------------------------------------------------------------------===//
27 // BufferLoopHoistingOp
28 //===----------------------------------------------------------------------===//
29 
30 DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
31  TransformRewriter &rewriter, Operation *target,
32  ApplyToEachResultList &results, TransformState &state) {
35 }
36 
37 void transform::BufferLoopHoistingOp::getEffects(
39  onlyReadsHandle(getTargetMutable(), effects);
40  modifiesPayload(effects);
41 }
42 
43 //===----------------------------------------------------------------------===//
44 // OneShotBufferizeOp
45 //===----------------------------------------------------------------------===//
46 
48  if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
49  return emitOpError() << "unsupported memcpy op";
50  if (getPrintConflicts() && !getTestAnalysisOnly())
51  return emitOpError() << "'print_conflicts' requires 'test_analysis_only'";
52  if (getDumpAliasSets() && !getTestAnalysisOnly())
53  return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'";
54  return success();
55 }
56 
58 transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
59  TransformResults &transformResults,
60  TransformState &state) {
62  options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops();
63  options.allowUnknownOps = getAllowUnknownOps();
64  options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
65  options.dumpAliasSets = getDumpAliasSets();
66  options.testAnalysisOnly = getTestAnalysisOnly();
67  options.printConflicts = getPrintConflicts();
68  if (getFunctionBoundaryTypeConversion().has_value())
69  options.setFunctionBoundaryTypeConversion(
70  *getFunctionBoundaryTypeConversion());
71  if (getMemcpyOp() == "memref.copy") {
72  options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
73  b.create<memref::CopyOp>(loc, from, to);
74  return success();
75  };
76  } else if (getMemcpyOp() == "linalg.copy") {
77  options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
78  b.create<linalg::CopyOp>(loc, from, to);
79  return success();
80  };
81  } else {
82  llvm_unreachable("invalid copy op");
83  }
84 
85  auto payloadOps = state.getPayloadOps(getTarget());
86  for (Operation *target : payloadOps) {
87  if (!isa<ModuleOp, FunctionOpInterface>(target))
88  return emitSilenceableError() << "expected module or function target";
89  auto moduleOp = dyn_cast<ModuleOp>(target);
90  if (options.bufferizeFunctionBoundaries) {
91  if (!moduleOp)
92  return emitSilenceableError() << "expected module target";
94  return emitSilenceableError() << "bufferization failed";
95  } else {
97  return emitSilenceableError() << "bufferization failed";
98  }
99  }
100 
101  // This transform op is currently restricted to ModuleOps and function ops.
102  // Such ops are modified in-place.
103  transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // EliminateEmptyTensorsOp
109 //===----------------------------------------------------------------------===//
110 
111 void transform::EliminateEmptyTensorsOp::getEffects(
113  onlyReadsHandle(getTargetMutable(), effects);
114  modifiesPayload(effects);
115 }
116 
117 DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
118  transform::TransformRewriter &rewriter, TransformResults &transformResults,
119  TransformState &state) {
120  for (Operation *target : state.getPayloadOps(getTarget())) {
121  if (failed(bufferization::eliminateEmptyTensors(rewriter, target)))
122  return mlir::emitSilenceableFailure(target->getLoc())
123  << "empty tensor elimination failed";
124  }
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // EmptyTensorToAllocTensorOp
130 //===----------------------------------------------------------------------===//
131 
132 DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
133  transform::TransformRewriter &rewriter, tensor::EmptyOp target,
135  rewriter.setInsertionPoint(target);
136  auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
137  target, target.getType(), target.getDynamicSizes());
138  results.push_back(alloc);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // Transform op registration
144 //===----------------------------------------------------------------------===//
145 
146 namespace {
147 /// Registers the bufferization transform ops and declares the Bufferization and
148 /// MemRef dialects as generated dialects (dialects whose ops these transforms may create).
149 class BufferizationTransformDialectExtension
151  BufferizationTransformDialectExtension> {
152 public:
153  using Base::Base;
154 
155  void init() {
156  declareGeneratedDialect<bufferization::BufferizationDialect>();
157  declareGeneratedDialect<memref::MemRefDialect>();
158 
159  registerTransformOps<
160 #define GET_OP_LIST
161 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
162  >();
163  }
164 };
165 } // namespace
166 
167 #define GET_OP_CLASSES
168 #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
169 
170 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
171 
173  DialectRegistry &registry) {
174  registry.addExtensions<BufferizationTransformDialectExtension>();
175 }
static llvm::ManagedStatic< PassManagerOptions > options
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
void registerTransformDialectExtension(DialectRegistry &registry)
void hoistBuffersFromLoops(Operation *op)
Within the given operation, hoist buffers from loops where possible.
LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for analysis-enabled bufferization.