MLIR  18.0.0git
IndependenceTransforms.cpp
Go to the documentation of this file.
1 //===- IndependenceTransforms.cpp - Make ops independent of values --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 
17 using namespace mlir;
18 using namespace mlir::memref;
19 
20 /// Make the given OpFoldResult independent of all independencies.
///
/// Attributes are returned unchanged. For an SSA value, an upper bound
/// (presburger::BoundType::UB, requested with closedUB=true, i.e. presumably
/// inclusive) that does not depend on any value in `independencies` is
/// computed and then materialized with `b` at `loc` via
/// affine::materializeComputedBound. The result is an OpFoldResult, so it may
/// be either an Attribute or a Value. Returns failure if no such bound can be
/// computed.
22  OpFoldResult ofr,
23  ValueRange independencies) {
// An Attribute is a constant and is returned as-is.
24  if (ofr.is<Attribute>())
25  return ofr;
26  Value value = ofr.get<Value>();
// Receive the computed bound as an affine map plus its (value, optional dim)
// operand list.
27  AffineMap boundMap;
28  ValueDimList mapOperands;
30  boundMap, mapOperands, presburger::BoundType::UB, value,
31  /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
32  return failure();
// Materialize the bound as IR using `b` at `loc`.
33  return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
34 }
35 
37  memref::AllocaOp allocaOp,
38  ValueRange independencies) {
// Build a replacement for `allocaOp` whose sizes are upper bounds that do not
// depend on any value in `independencies` (computed with makeIndependent).
// Returns allocaOp's own result if no size changed, the result of a
// memref.subview of a new memref.alloca otherwise, or failure if some size
// cannot be bounded. New IR is inserted immediately before `allocaOp`.
40  b.setInsertionPoint(allocaOp);
41  Location loc = allocaOp.getLoc();
42
// Bound every size of allocaOp (Attribute sizes pass through makeIndependent
// unchanged); give up as soon as one size cannot be bounded.
44  for (OpFoldResult ofr : allocaOp.getMixedSizes()) {
45  auto ub = makeIndependent(b, loc, ofr, independencies);
46  if (failed(ub))
47  return failure();
48  newSizes.push_back(*ub);
49  }
50
51  // Return existing memref::AllocaOp if nothing has changed.
52  if (llvm::equal(allocaOp.getMixedSizes(), newSizes))
53  return allocaOp.getResult();
54
55  // Create a new memref::AllocaOp.
// NOTE(review): only the new sizes and the element type are forwarded here;
// allocaOp's memory space and alignment (if any) are not carried over, so the
// replacement presumably lands in the default memory space -- confirm this is
// intended.
56  Value newAllocaOp =
57  b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType());
58
59  // Create a memref::SubViewOp.
// The subview uses zero offsets, unit strides and the ORIGINAL sizes, so the
// returned value has allocaOp's shape while viewing the (possibly larger) new
// allocation. Its inferred type may carry a strided layout that differs from
// allocaOp's type.
60  SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0));
61  SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1));
62  return b
63  .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(),
64  strides)
65  .getResult();
66 }
67 
68 /// Push down an UnrealizedConversionCastOp past a SubViewOp.
69 static UnrealizedConversionCastOp
71  UnrealizedConversionCastOp conversionOp, SubViewOp op) {
72  OpBuilder::InsertionGuard g(rewriter);
73  rewriter.setInsertionPoint(op);
74  auto newResultType = cast<MemRefType>(SubViewOp::inferRankReducedResultType(
75  op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(),
76  op.getMixedSizes(), op.getMixedStrides()));
77  Value newSubview = rewriter.create<SubViewOp>(
78  op.getLoc(), newResultType, conversionOp.getOperand(0),
79  op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides());
80  auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>(
81  op.getLoc(), op.getType(), newSubview);
82  rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0));
83  return newConversionOp;
84 }
85 
86 /// Given an original op and a new, modified op with the same number of results,
87 /// whose memref return types may differ, replace all uses of the original op
88 /// with the new op and propagate the new memref types through the IR.
89 ///
90 /// Example:
91 /// %from = memref.alloca(%sz) : memref<?xf32>
92 /// %to = memref.subview ... : ... to memref<?xf32, strided<[1], offset: ?>>
93 /// memref.store %cst, %from[%c0] : memref<?xf32>
94 ///
95 /// In the above example, all uses of %from are replaced with %to. This can be
96 /// done directly for ops such as memref.store. For ops that have memref results
97 /// (e.g., memref.subview), the result type may depend on the operand type, so
98 /// we cannot just replace all uses. There is special handling for common memref
99 /// ops. For all other ops, unrealized_conversion_cast is inserted.
101  Operation *from, Operation *to) {
102  assert(from->getNumResults() == to->getNumResults() &&
103  "expected same number of results");
104  OpBuilder::InsertionGuard g(rewriter);
105  rewriter.setInsertionPointAfter(to);
106 
107  // Wrap new results in unrealized_conversion_cast and replace all uses of the
108  // original op.
109  SmallVector<UnrealizedConversionCastOp> unrealizedConversions;
110  for (const auto &it :
111  llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) {
112  unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>(
113  to->getLoc(), std::get<0>(it.value()).getType(),
114  std::get<1>(it.value())));
115  rewriter.replaceAllUsesWith(from->getResult(it.index()),
116  unrealizedConversions.back()->getResult(0));
117  }
118 
119  // Push unrealized_conversion_cast ops further down in the IR. I.e., try to
120  // wrap results instead of operands in a cast.
121  for (int i = 0; i < static_cast<int>(unrealizedConversions.size()); ++i) {
122  UnrealizedConversionCastOp conversion = unrealizedConversions[i];
123  assert(conversion->getNumOperands() == 1 &&
124  conversion->getNumResults() == 1 &&
125  "expected single operand and single result");
126  SmallVector<Operation *> users = llvm::to_vector(conversion->getUsers());
127  for (Operation *user : users) {
128  // Handle common memref dialect ops that produce new memrefs and must
129  // be recreated with the new result type.
130  if (auto subviewOp = dyn_cast<SubViewOp>(user)) {
131  unrealizedConversions.push_back(
132  propagateSubViewOp(rewriter, conversion, subviewOp));
133  continue;
134  }
135 
136  // TODO: Other memref ops such as memref.collapse_shape/expand_shape
137  // should also be handled here.
138 
139  // Skip any ops that produce MemRef result or have MemRef region block
140  // arguments. These may need special handling (e.g., scf.for).
141  if (llvm::any_of(user->getResultTypes(),
142  [](Type t) { return isa<MemRefType>(t); }))
143  continue;
144  if (llvm::any_of(user->getRegions(), [](Region &r) {
145  return llvm::any_of(r.getArguments(), [](BlockArgument bbArg) {
146  return isa<MemRefType>(bbArg.getType());
147  });
148  }))
149  continue;
150 
151  // For all other ops, we assume that we can directly replace the operand.
152  // This may have to be revised in the future; e.g., there may be ops that
153  // do not support non-identity layout maps.
154  for (OpOperand &operand : user->getOpOperands()) {
155  if ([[maybe_unused]] auto castOp =
156  operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
157  rewriter.updateRootInPlace(
158  user, [&]() { operand.set(conversion->getOperand(0)); });
159  }
160  }
161  }
162  }
163 
164  // Erase all unrealized_conversion_cast ops without uses.
165  for (auto op : unrealizedConversions)
166  if (op->getUses().empty())
167  rewriter.eraseOp(op);
168 }
169 
171  memref::AllocaOp allocaOp,
172  ValueRange independencies) {
173  auto replacement =
174  memref::buildIndependentOp(rewriter, allocaOp, independencies);
175  if (failed(replacement))
176  return failure();
177  replaceAndPropagateMemRefType(rewriter, allocaOp,
178  replacement->getDefiningOp());
179  return replacement;
180 }
181 
182 memref::AllocaOp memref::allocToAlloca(
183  RewriterBase &rewriter, memref::AllocOp alloc,
184  function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
185  memref::DeallocOp dealloc = nullptr;
186  for (Operation &candidate :
187  llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
188  dealloc = dyn_cast<memref::DeallocOp>(candidate);
189  if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
190  (!filter || filter(alloc, dealloc))) {
191  break;
192  }
193  }
194 
195  if (!dealloc)
196  return nullptr;
197 
198  OpBuilder::InsertionGuard guard(rewriter);
199  rewriter.setInsertionPoint(alloc);
200  auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
201  alloc, alloc.getMemref().getType(), alloc.getOperands());
202  rewriter.eraseOp(dealloc);
203  return alloca;
204 }
static void replaceAndPropagateMemRefType(RewriterBase &rewriter, Operation *from, Operation *to)
Given an original op and a new, modified op with the same number of results, whose memref return type...
static FailureOr< OpFoldResult > makeIndependent(OpBuilder &b, Location loc, OpFoldResult ofr, ValueRange independencies)
Make the given OpFoldResult independent of all independencies.
static UnrealizedConversionCastOp propagateSubViewOp(RewriterBase &rewriter, UnrealizedConversionCastOp conversionOp, SubViewOp op)
Push down an UnrealizedConversionCastOp past a SubViewOp.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class represents an operand of an operation.
Definition: Value.h:261
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_range getResults()
Definition: Operation.h:410
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:825
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:606
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static LogicalResult computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type, Value value, std::optional< int64_t > dim, ValueRange independencies, bool closedUB=false)
Compute a bound that is independent of all values in independencies.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
OpFoldResult materializeComputedBound(OpBuilder &b, Location loc, AffineMap boundMap, ArrayRef< std::pair< Value, std::optional< int64_t >>> mapOperands)
Materialize an already computed bound with Affine dialect ops.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< Value > buildIndependentOp(OpBuilder &b, AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
FailureOr< Value > replaceWithIndependentOp(RewriterBase &rewriter, memref::AllocaOp allocaOp, ValueRange independencies)
Build a new memref::AllocaOp whose dynamic sizes are independent of all given independencies.
memref::AllocaOp allocToAlloca(RewriterBase &rewriter, memref::AllocOp alloc, function_ref< bool(memref::AllocOp, memref::DeallocOp)> filter=nullptr)
Replaces the given alloc with the corresponding alloca and returns it if the following conditions are...
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72