//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(OpOperand *op) {
  if (auto enc = getSparseTensorEncoding(op->get().getType())) {
    ArrayRef<SparseTensorEncodingAttr::DimLevelType> dimTypes =
        enc.getDimLevelType();
    for (auto dimType : dimTypes)
      if (dimType == SparseTensorEncodingAttr::DimLevelType::Compressed)
        return true; // at least one compressed
  }
  return false;
}
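// As an illustration (not from this file): an operand satisfies this check if
// its tensor type carries an encoding with at least one compressed dimension,
// e.g. a hypothetical CSR-style alias such as
//
//   #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
//
// used in a type like tensor<32x64xf64, #CSR>.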

// Helper method to find zero/uninitialized allocation.
static bool isAlloc(OpOperand *op, bool isZero) {
  Value val = op->get();
  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
    Value copy = alloc.getCopy();
    if (isZero)
      return copy && (matchPattern(copy, m_Zero()) ||
                      matchPattern(copy, m_AnyZeroFloat()));
    return !copy;
  }
  return false;
}
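// As a rough sketch of what this matches (names and shapes are made up):
// with isZero=false it accepts an uninitialized allocation such as
//   %a = bufferization.alloc_tensor() : tensor<8x8xf64>
// and with isZero=true a zero-copied allocation such as
//   %b = bufferization.alloc_tensor() copy(%zero) : tensor<8x8xf64>
// where %zero is a (splat) zero constant.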

// Helper to detect sampling operation.
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}

// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = val.dyn_cast<BlockArgument>())
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}
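// As a sketch of the matched reduction body (types are illustrative only):
//
//   %m = arith.mulf %a, %b : f64
//   %s = arith.addf %x, %m : f64
//   linalg.yield %s : f64
//
// where %x is the block argument that carries the output (reduction) value.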

// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
      return matchPattern(t->get(), m_Zero()) ||
             matchPattern(t->get(), m_AnyZeroFloat());
    }
  } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
  }
  return false;
}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// Rewriting rule that folds a direct yield of zero into the initial
/// allocation.
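/// Schematically (details elided; names and shapes are illustrative only):
///
///   %a = bufferization.alloc_tensor() : tensor<8x8xf64>
///   %0 = linalg.generic ... outs(%a : tensor<8x8xf64>) { ... yield zero ... }
///
/// becomes a zero-initialized allocation that directly replaces %0.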
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
        !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op))
      return failure();
    auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
    if (!outputType.hasStaticShape() || getSparseTensorEncoding(outputType))
      return failure();
    // Incorporate zero value into allocation copy.
    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
    AllocTensorOp a =
        op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
    rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
    rewriter.replaceOp(op, op.getOutputOperand(0)->get());
    return success();
  }
};

/// Rewriting rule that converts two kernels:
///
///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///   X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using the distributive law:
///
///   X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for a sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
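/// Roughly, if the sampling tensor S has nnz(S) stored entries, the fused
/// kernel only needs to evaluate the k-reduction at those nnz(S) positions,
/// instead of first materializing the intermediate T(i,j) for every (i,j).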
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isAlloc(op.getOutputOperand(0), /*isZero=*/false) ||
        !isAlloc(prod.getOutputOperand(0), /*isZero=*/true) ||
        !isSampling(op) || !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputOperands();
    SmallVector<Value> outputOps = op.getOutputOperands();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.getRegion().front();
    Block &consBlock = op.getRegion().front();
    BlockAndValueMapping mapper;
    Block *fusedBlock = new Block();
    fusedOp.getRegion().push_back(fusedBlock);
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    rewriter.setInsertionPointToStart(fusedBlock);
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Force initial value on merged allocation for dense outputs.
    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
      Value init = prod.getOutputOperand(0)
                       ->get()
                       .getDefiningOp<AllocTensorOp>()
                       .getCopy();
      AllocTensorOp a =
          op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
      rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
    }
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

/// Sparse rewriting rule for reshape operator.
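/// As an illustration (a sketch; #CSR stands for some sparse encoding alias),
/// a dense-to-sparse expansion such as
///
///   %1 = tensor.expand_shape %0 [[0, 1]]
///        : tensor<12xf64> into tensor<3x4xf64, #CSR>
///
/// is rewritten into a dense reshape followed by a sparse conversion:
///
///   %d = tensor.expand_shape %0 [[0, 1]]
///        : tensor<12xf64> into tensor<3x4xf64>
///   %1 = sparse_tensor.convert %d
///        : tensor<3x4xf64> to tensor<3x4xf64, #CSR>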
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (a change of view), for a
    // sparse2dense or dense2sparse reshape we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    } else if (encSrc) {
      RankedTensorType rtp =
          op.getSrc().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      op->setOperand(0, convert);
      return success();
    } else if (encDst) {
      RankedTensorType rtp =
          op.getResult().getType().template cast<RankedTensorType>();
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                                op.getReassociation());
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
               ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
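
// Example usage (a sketch, not part of this file): a pass that wants these
// patterns can register them and run the greedy rewrite driver, e.g.
//
//   RewritePatternSet patterns(ctx);
//   populateSparseTensorRewriting(patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));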