MLIR 22.0.0git
IndependenceTransforms.cpp
Go to the documentation of this file.
1//===- IndependenceTransforms.cpp - Make ops independent of values --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14
15using namespace mlir;
16using namespace mlir::memref;
17
18/// Make the given OpFoldResult independent of all independencies.
19static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
20 OpFoldResult ofr,
21 ValueRange independencies) {
22 if (isa<Attribute>(ofr))
23 return ofr;
24 AffineMap boundMap;
25 ValueDimList mapOperands;
27 boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
28 /*closedUB=*/true)))
29 return failure();
30 return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
31}
32
33FailureOr<Value> memref::buildIndependentOp(OpBuilder &b,
34 memref::AllocaOp allocaOp,
35 ValueRange independencies) {
37 b.setInsertionPoint(allocaOp);
38 Location loc = allocaOp.getLoc();
39
41 for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
42 auto ub = makeIndependent(b, loc, ofr, independencies);
43 if (failed(ub))
44 return failure();
45 newSizes.push_back(*ub);
46 }
47
48 // Return existing memref::AllocaOp if nothing has changed.
49 if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
50 return allocaOp.getResult();
51
52 // Create a new memref::AllocaOp.
53 Value newAllocaOp =
54 AllocaOp::create(b, loc, newSizes, allocaOp.getType().getElementType());
55
56 // Create a memref::SubViewOp.
57 SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
58 SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
59 return SubViewOp::create(b, loc, newAllocaOp, offsets,
60 allocaOp.getMixedSizes(), strides)
61 .getResult();
62}
63
64/// Push down an UnrealizedConversionCastOp past a SubViewOp.
65static UnrealizedConversionCastOp
67 UnrealizedConversionCastOp conversionOp, SubViewOp op) {
68 OpBuilder::InsertionGuard g(rewriter);
69 rewriter.setInsertionPoint(op);
70 MemRefType newResultType = SubViewOp::inferRankReducedResultType(
71 op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
72 op.getMixedSizes(), op.getMixedStrides());
73 Value newSubview = SubViewOp::create(
74 rewriter, op.getLoc(), newResultType, conversionOp.getOperand(0),
75 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
76 auto newConversionOp = UnrealizedConversionCastOp::create(
77 rewriter, op.getLoc(), op.getType(), newSubview);
78 rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
79 return newConversionOp;
80}
81
82/// Given an original op and a new, modified op with the same number of results,
83/// whose memref return types may differ, replace all uses of the original op
84/// with the new op and propagate the new memref types through the IR.
85///
86/// Example:
87/// %from = memref.alloca(%sz) : memref<?xf32>
88/// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
89/// memref.store %cst, %from[%c0] : memref<?xf32>
90///
91/// In the above example, all uses of %from are replaced with %to. This can be
92/// done directly for ops such as memref.store. For ops that have memref results
93/// (e.g., memref.subview), the result type may depend on the operand type, so
94/// we cannot just replace all uses. There is special handling for common memref
95/// ops. For all other ops, unrealized_conversion_cast is inserted.
97 Operation *from, Operation *to) {
98 assert(from->getNumResults() == to->getNumResults() &&
99 "expected same number of results");
100 OpBuilder::InsertionGuard g(rewriter);
101 rewriter.setInsertionPointAfter(to);
102
103 // Wrap new results in unrealized_conversion_cast and replace all uses of the
104 // original op.
105 SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
106 for (const auto &it :
107 llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
108 unrealizedConversions.push_back(UnrealizedConversionCastOp::create(
109 rewriter, to->getLoc(), std::get<0>(it.value()).getType(),
110 std::get<1>(it.value())));
111 rewriter.replaceAllUsesWith(from->getResult(it.index()),
112 unrealizedConversions.back()->getResult(0));
113 }
114
115 // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
116 // wrap results instead of operands in a cast.
117 for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
118 UnrealizedConversionCastOp conversion = unrealizedConversions[i];
119 assert(conversion->getNumOperands() == 1 &&
120 conversion->getNumResults() == 1 &&
121 "expected single operand and single result");
122 SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers());
123 for (Operation *user : users) {
124 // Handle common memref dialect ops that produce new memrefs and must
125 // be recreated with the new result type.
126 if (auto subviewOp = dyn_cast<SubViewOp>(user)) {
127 unrealizedConversions.push_back(
128 propagateSubViewOp(rewriter, conversion, subviewOp));
129 continue;
130 }
131
132 // TODO: Other memref ops such as memref.collapse_shape/expand_shape
133 // should also be handled here.
134
135 // Skip any ops that produce MemRef result or have MemRef region block
136 // arguments. These may need special handling (e.g., scf.for).
137 if (llvm::any_of(user->getResultTypes(),
138 [](Type t) { return isa<MemRefType>(t); }))
139 continue;
140 if (llvm::any_of(user->getRegions(), [](Region &r) {
141 return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
142 return isa<MemRefType>(bbArg.getType());
143 });
144 }))
145 continue;
146
147 // For all other ops, we assume that we can directly replace the operand.
148 // This may have to be revised in the future; e.g., there may be ops that
149 // do not support non-identity layout maps.
150 for (OpOperand &operand : user->getOpOperands()) {
151 if ([[maybe_unused]] auto castOp =
152 operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
153 rewriter.modifyOpInPlace(
154 user, [&]() { operand.set(conversion->getOperand(0)); });
155 }
156 }
157 }
158 }
159
160 // Erase all unrealized_conversion_cast ops without uses.
161 for (auto op : unrealizedConversions)
162 if (op->getUses().empty())
163 rewriter.eraseOp(op);
164}
165
167 memref::AllocaOp allocaOp,
168 ValueRange independencies) {
169 auto replacement =
170 memref::buildIndependentOp(rewriter, allocaOp, independencies);
171 if (failed(replacement))
172 return failure();
173 replaceAndPropagateMemRefType(rewriter, allocaOp,
174 replacement->getDefiningOp());
175 return replacement;
176}
177
178memref::AllocaOp memref::allocToAlloca(
179 RewriterBase &rewriter, memref::AllocOp alloc,
180 function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
181 memref::DeallocOp dealloc = nullptr;
182 for (Operation &candidate :
183 llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
184 dealloc = dyn_cast<memref::DeallocOp>(candidate);
185 if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
186 (!filter || filter(alloc, dealloc))) {
187 break;
188 }
189 }
190
191 if (!dealloc)
192 return nullptr;
193
194 OpBuilder::InsertionGuard guard(rewriter);
195 rewriter.setInsertionPoint(alloc);
196 auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
197 alloc, alloc.getMemref().getType(), alloc.getOperands());
198 rewriter.eraseOp(dealloc);
199 return alloca;
200}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
Returns failure if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted. Since `begin` can itself be invalidated by the rewriting, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, ...
static FailureOr< OpFoldResult > makeIndependent(OpBuilder &b, Location loc, OpFoldResult ofr, ValueRange independencies)
Make the given OpFoldResult independent of all independencies.
static void replaceAndPropagateMemRefType(RewriterBase &rewriter, Operation *from, Operation *to)
Given an original op and a new, modified op with the same number of results, whose memref return type...
static UnrealizedConversionCastOp propagateSubViewOp(RewriterBase &rewriter, UnrealizedConversionCastOp conversionOp, SubViewOp op)
Push down an UnrealizedConversionCastOp past a SubViewOp.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_range getResults()
Definition Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, const Variable &var, ValueRange independencies, bool closedUB=false)
Compute a bound in that is independent of all values in independencies.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
OpFoldResult materializeComputedBound(OpBuilder &b, Location loc, AffineMap boundMap, ArrayRef< std::pair< Value, std::optional< int64_t > > > mapOperands)
Materialize an already computed bound with Affine dialect ops.
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
FailureOr< Value > buildIndependentOp(OpBuilder &b, AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
SmallVector< std::pair< Value, std::optional< int64_t > > > ValueDimList
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152