28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/ADT/ScopeExit.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Support/Debug.h"
37 #define DEBUG_TYPE "linalg-fusion"
70 bool fromSubViewOpOnly =
false) {
73 for (
OpOperand &opOperand : op->getOpOperands()) {
80 if (fromSubViewOpOnly &&
81 !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
82 opOperand.get().getDefiningOp()))
85 AffineMap map = op.getMatchingIndexingMap(&opOperand);
86 LLVM_DEBUG(llvm::dbgs() <<
"getShapeDefiningLoopRange I/O idx: "
87 << opOperand.getOperandNumber() <<
"\n");
88 LLVM_DEBUG(llvm::dbgs()
89 <<
"getShapeDefiningLoopRange map: " << map <<
"\n");
91 auto dimExpr = dyn_cast<AffineDimExpr>(en.value());
94 if (loopDepth == cast<AffineDimExpr>(en.value()).getPosition()) {
95 LLVM_DEBUG(llvm::dbgs() <<
"getShapeDefiningLoopRange loopDepth: "
96 << loopDepth <<
"\n");
97 LLVM_DEBUG(llvm::dbgs() <<
"getShapeDefiningLoopRange shape: "
98 << opOperand.get() <<
"\n");
100 static_cast<unsigned>(en.index())};
104 llvm_unreachable(
"Expect to be able to extract a shape defining loop range");
108 return producer->getOperands();
120 for (
unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
124 sizeBounds.push_back(dim);
125 auto it = fusedLoopsAndRanges.find(i);
126 if (it != fusedLoopsAndRanges.end()) {
127 ivs.push_back(it->second.offset);
128 tileSizes.push_back(it->second.size);
129 loopRanges.push_back(it->second);
130 LLVM_DEBUG(llvm::dbgs() <<
"tiled loop#" << i <<
" with LoopRange "
131 << loopRanges.back() <<
"\n");
135 LLVM_DEBUG(llvm::dbgs() <<
"full loop#" << i <<
" with LoopRange "
136 << loopRanges.back() <<
"\n");
141 clonedShapes.reserve(producer->getNumOperands());
151 resultTypes.reserve(producer->getNumResults());
152 int64_t firstInitOperandIdx =
154 for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
155 resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].
getType());
159 LinalgOp clonedOp =
clone(b, producer, resultTypes, clonedShapes);
163 llvm::map_range(loopRanges, [&](
Range range) {
return range.
offset; }));
172 Value shapedOperand,
unsigned dim) {
174 if (
auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
175 return subViewOp.getOrCreateRanges(b, loc)[dim];
176 if (
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
177 return sliceOp.getOrCreateRanges(b, loc)[dim];
178 llvm_unreachable(
"SubviewOp or ExtractSliceOp expected");
186 LLVM_DEBUG(llvm::dbgs() <<
"Producer map: " << producerMap <<
"\n");
188 Value shapedOperand = consumerOpOperand.
get();
190 unsigned posInProducerLoop = cast<AffineDimExpr>(en.value()).getPosition();
192 b, consumerOpOperand.
getOwner()->
getLoc(), shapedOperand, en.index());
194 return fuse(b, producerOp, fusedLoopsAndRanges);
204 if (!isa<RankedTensorType>(tensor.
getType()))
208 LLVM_DEBUG(llvm::dbgs() <<
"\ngetProducerOfTensor: " << tensor);
210 opResult = cast<OpResult>(tensor);
213 if (
auto sliceOp = tensor.
getDefiningOp<tensor::ExtractSliceOp>()) {
214 tensor = sliceOp.getSource();
217 if (
auto blockArg = dyn_cast<BlockArgument>(tensor)) {
218 if (
auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
219 tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
227 FailureOr<FusionInfo>
229 Value inputTensor = consumerOpOperand.
get();
232 if (!producerOpResult) {
233 LLVM_DEBUG(llvm::dbgs() <<
"\nUnable to find producer");
239 FailureOr<FusionInfo>
242 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.
getOwner());
246 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.
getOwner());
250 Value inputTensor = consumerOpOperand.
get();
253 auto sliceOp = inputTensor.
getDefiningOp<tensor::ExtractSliceOp>();
255 LLVM_DEBUG(llvm::dbgs()
256 <<
"\nNot fusable, not an extract_slice op: " << inputTensor);
268 LLVM_DEBUG(llvm::dbgs() <<
"Fuse into consumer: " << *consumerOp <<
"\n");
271 LinalgOp fusedProducer =
272 fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
280 if (cast<ShapedType>(consumerType).getRank() !=
281 cast<ShapedType>(def.
getType()).getRank()) {
282 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
290 if (consumerType != def.
getType())
291 def = b.
create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
292 consumerOpOperand.
set(def);
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, const DenseMap< unsigned, Range > &fusedLoopsAndRanges)
Fuses the producer by cloning the producer.
static void getProducerOfTensor(Value tensor, OpResult &opResult)
Walk back the use-def chain through scf::ForOp yields.
static SmallVector< Value > getTiledOperands(LinalgOp producer)
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, bool fromSubViewOpOnly=false)
static Range getRangeFromOperandShape(OpBuilder &b, Location loc, Value shapedOperand, unsigned dim)
Get the loop range for a dimension dim based on the shapedOperand.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
IntegerAttr getIndexAttr(int64_t value)
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class provides a mutable adaptor for a range of operands.
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< FusionInfo > fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand)
This implements the fusion part of the "tileAndFuse on tensors" transformation and thus requires the ...
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src, const llvm::SmallBitVector &dropDims)
Create tensor.collapse_shape to drop unit dimensions in dropDims in tensor src.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Implements a simple high-level fusion pass on linalg structured operations.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A struct containing the Linalg producer before and after fusion.