MLIR 22.0.0git
BufferizationTransformOps.cpp
Go to the documentation of this file.
1//===- BufferizationTransformOps.cpp - Bufferization transform ops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21
22using namespace mlir;
23using namespace mlir::bufferization;
24using namespace mlir::transform;
25
26//===----------------------------------------------------------------------===//
27// BufferLoopHoistingOp
28//===----------------------------------------------------------------------===//
29
// Per-payload-op entry point of BufferLoopHoistingOp.
// NOTE(review): this extracted listing drops the line carrying the leading
// parameters and the lines holding the body, so neither is visible here.
// Presumably the body hoists buffers out of loops within the target op and
// returns success -- confirm against the original source file.
30DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
32 ApplyToEachResultList &results, TransformState &state) {
35}
36
// Populates `effects` for BufferLoopHoistingOp: the target handle is marked
// as only read, and the payload IR is marked as modified.
// NOTE(review): the line declaring the `effects` parameter is missing from
// this extracted listing.
37void transform::BufferLoopHoistingOp::getEffects(
39 onlyReadsHandle(getTargetMutable(), effects);
40 modifiesPayload(effects);
41}
42
43//===----------------------------------------------------------------------===//
44// OneShotBufferizeOp
45//===----------------------------------------------------------------------===//
46
47LogicalResult transform::OneShotBufferizeOp::verify() {
48 if (getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
49 return emitOpError() << "unsupported memcpy op";
50 if (getPrintConflicts() && !getTestAnalysisOnly())
51 return emitOpError() << "'print_conflicts' requires 'test_analysis_only'";
52 if (getDumpAliasSets() && !getTestAnalysisOnly())
53 return emitOpError() << "'dump_alias_sets' requires 'test_analysis_only'";
54 return success();
55}
56
// Applies One-Shot Bufferize to every payload op mapped to the target handle.
//
// The op's attributes are copied into `options`, a copy callback is selected
// from the `memcpy_op` name, and each payload op is then bufferized. A
// silenceable error is emitted when a payload op is neither a module nor a
// function-like op, when function-boundary bufferization is requested on a
// non-module target, or when bufferization itself fails. On the non-failing
// path the `transformed` result is set to the same payload ops.
//
// NOTE(review): this extracted listing drops several lines of the function
// (the return type, the declaration of `options`, the heads of the two
// bufferization calls, and the final return) -- confirm details against the
// original source file.
58transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
59 TransformResults &transformResults,
60 TransformState &state) {
   // Mirror the op's attributes into the bufferization options.
62 options.allowReturnAllocsFromLoops = getAllowReturnAllocsFromLoops();
63 options.allowUnknownOps = getAllowUnknownOps();
64 options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
65 options.dumpAliasSets = getDumpAliasSets();
66 options.testAnalysisOnly = getTestAnalysisOnly();
67 options.printConflicts = getPrintConflicts();
   // The function-boundary type conversion is optional; override only if set.
68 if (getFunctionBoundaryTypeConversion().has_value())
69 options.setFunctionBoundaryTypeConversion(
70 *getFunctionBoundaryTypeConversion());
   // Pick the callback that creates the copy op named by `memcpy_op`.
   // verify() rejects any name other than the two handled here, which is why
   // the final branch is marked unreachable.
71 if (getMemcpyOp() == "memref.copy") {
72 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
73 memref::CopyOp::create(b, loc, from, to);
74 return success();
75 };
76 } else if (getMemcpyOp() == "linalg.copy") {
77 options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
78 linalg::CopyOp::create(b, loc, from, to);
79 return success();
80 };
81 } else {
82 llvm_unreachable("invalid copy op");
83 }
84
85 auto payloadOps = state.getPayloadOps(getTarget());
   // A single BufferizationState is shared by all payload ops in the loop.
86 BufferizationState bufferizationState;
87
88 for (Operation *target : payloadOps) {
   // Only modules and function-like ops are accepted as targets.
89 if (!isa<ModuleOp, FunctionOpInterface>(target))
90 return emitSilenceableError() << "expected module or function target";
91 auto moduleOp = dyn_cast<ModuleOp>(target);
92 if (options.bufferizeFunctionBoundaries) {
   // Function-boundary bufferization is only supported on a module target.
93 if (!moduleOp)
94 return emitSilenceableError() << "expected module target";
   // NOTE(review): the head of the call checked here is missing from the
   // listing; presumably runOneShotModuleBufferize -- confirm.
96 bufferizationState)))
97 return emitSilenceableError() << "bufferization failed";
98 } else {
    // NOTE(review): the head of the call checked here is missing from the
    // listing; presumably runOneShotBufferize -- confirm.
100 bufferizationState)))
101 return emitSilenceableError() << "bufferization failed";
102 }
103 }
104
105 // This transform op is currently restricted to ModuleOps and function ops.
106 // Such ops are modified in-place.
107 transformResults.set(cast<OpResult>(getTransformed()), payloadOps);
109}
110
111//===----------------------------------------------------------------------===//
112// EliminateEmptyTensorsOp
113//===----------------------------------------------------------------------===//
114
// Populates `effects` for EliminateEmptyTensorsOp: the target handle is
// marked as only read, and the payload IR is marked as modified.
// NOTE(review): the line declaring the `effects` parameter is missing from
// this extracted listing.
115void transform::EliminateEmptyTensorsOp::getEffects(
117 onlyReadsHandle(getTargetMutable(), effects);
118 modifiesPayload(effects);
119}
120
// Visits every payload op mapped to the target handle and reports a
// silenceable failure, located at the offending payload op, as soon as one of
// them cannot be processed.
// NOTE(review): the condition guarding the failure and the final return are
// on lines this extracted listing drops; presumably the condition checks the
// result of eliminateEmptyTensors(rewriter, target) -- confirm.
121DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply(
122 transform::TransformRewriter &rewriter, TransformResults &transformResults,
123 TransformState &state) {
124 for (Operation *target : state.getPayloadOps(getTarget())) {
126 return mlir::emitSilenceableFailure(target->getLoc())
127 << "empty tensor elimination failed";
128 }
130}
131
132//===----------------------------------------------------------------------===//
133// EmptyTensorToAllocTensorOp
134//===----------------------------------------------------------------------===//
135
// Replaces a single tensor::EmptyOp with a bufferization::AllocTensorOp that
// has the same result type and dynamic sizes, then appends the new op to
// `results`.
// NOTE(review): the line with the remaining parameters (including `results`)
// and the final return are missing from this extracted listing.
136DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(
137 transform::TransformRewriter &rewriter, tensor::EmptyOp target,
    // Create the replacement at the position of the op being replaced.
139 rewriter.setInsertionPoint(target);
140 auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
141 target, target.getType(), target.getDynamicSizes());
142 results.push_back(alloc);
144}
145
146//===----------------------------------------------------------------------===//
147// Transform op registration
148//===----------------------------------------------------------------------===//
149
150namespace {
151/// Registers the bufferization transform ops and declares the bufferization
152/// and memref dialects as generated dialects (see init()).
// NOTE(review): the base-class line and the head of the type-ID macro
// invocation are missing from this extracted listing.
153class BufferizationTransformDialectExtension
155 BufferizationTransformDialectExtension> {
156public:
158 BufferizationTransformDialectExtension)
159
160 using Base::Base;
161
    // Declares the generated dialects and registers every op listed in the
    // generated BufferizationTransformOps.cpp.inc op list.
162 void init() {
163 declareGeneratedDialect<bufferization::BufferizationDialect>();
164 declareGeneratedDialect<memref::MemRefDialect>();
165
166 registerTransformOps<
167#define GET_OP_LIST
168#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
169
170 >();
171 }
172};
173} // namespace
174
175#define GET_OP_CLASSES
176#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
177
178#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.cpp.inc"
179
// Adds BufferizationTransformDialectExtension to the given dialect registry.
// NOTE(review): the first line of this definition (return type and function
// name) is missing from this extracted listing; presumably this is
// registerTransformDialectExtension -- confirm against the original file.
181 DialectRegistry &registry) {
182 registry.addExtensions<BufferizationTransformDialectExtension>();
183}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the op itself).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation, co-indexed with the results of the transform op.
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
void registerTransformDialectExtension(DialectRegistry &registry)
void hoistBuffersFromLoops(Operation *op)
Within the given operation, hoist buffers from loops where possible.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op)
Try to eliminate "tensor.empty" ops inside op.
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Options for analysis-enabled bufferization.