//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with proper separation of concerns as far as sparsification and
// vectorization are concerned. However, this pass is neither the final
// abstraction level we want nor the general vectorizer we want. It forms a
// good stepping stone for incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}
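
// For example (illustrative only, not emitted by this file): with
// VL{16, /*VLA=*/false, /*idx32=*/true}, an f32 element type maps to
// vector<16xf32>, whereas VL{4, /*VLA=*/true, /*idx32=*/false} maps to the
// scalable type vector<[4]xf32> used for ArmSVE-style VLA vectorization.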

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return vector::BroadcastOp::create(rewriter, loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return vector::CreateMaskOp::create(rewriter, loc, mtp, end);
}
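
// Illustrative IR for the general (non-divisible) case, assuming a fixed
// vector length of 8; the special case above folds to a constant true mask:
//
//   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
//   %mask = vector.create_mask %end : vector<8xi1>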

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return vector::BroadcastOp::create(rewriter, val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return vector::GatherOp::create(rewriter, loc, vtp, mem, scalarArgs,
                                    indexVec, vmask, pass);
  }
  return vector::MaskedLoadOp::create(rewriter, loc, vtp, mem, idxs, vmask,
                                      pass);
}
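
// Illustrative forms generated here, assuming vector<8xf32> and either a
// direct trailing subscript or an index-vector subscript (types abbreviated):
//
//   direct  : %v = vector.maskedload %mem[%i], %vmask, %pass
//   indirect: %v = vector.gather %mem[%c0] [%indexVec], %vmask, %pass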

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs);
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask,
                              rhs);
    return;
  }
  vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the
/// combining kind of the reduction on success in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return vector::InsertOp::create(rewriter, loc, r,
                                    constantZero(rewriter, loc, vtp),
                                    constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return vector::InsertOp::create(rewriter, loc, r,
                                    constantOne(rewriter, loc, vtp),
                                    constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return vector::BroadcastOp::create(rewriter, loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}
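
// For example, an ADD (or XOR) reduction with scalar init %r and an assumed
// vector<4xf64> type yields the following illustrative IR, so that a final
// horizontal vector.reduction recovers r combined with all partial sums:
//
//   %zero = arith.constant dense<0.000000e+00> : vector<4xf64>
//   %init = vector.insert %r, %zero [0] : f64 into vector<4xf64>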

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = arith::ExtUIOp::create(
                rewriter, loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = arith::ExtUIOp::create(
                rewriter, loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                arith::AddIOp::create(rewriter, forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}
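
// Illustrative outcome for an indirect subscript a[ind[i]], where %ind holds
// the coordinates of a sparse level (assuming i8 coordinates): the scalar
// load of ind[i] becomes a masked vector load that is zero extended to a
// vector<..xi32> index vector and then feeds the gather/scatter issued by
// genVectorLoad/genVectorStore above.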

#define UNAOP(xxx) \
  if (isa<xxx>(def)) { \
    if (codegen) \
      vexp = xxx::create(rewriter, loc, vx); \
    return true; \
  }

#define TYPEDUNAOP(xxx) \
  if (auto x = dyn_cast<xxx>(def)) { \
    if (codegen) { \
      VectorType vtp = vectorType(vl, x.getType()); \
      vexp = xxx::create(rewriter, loc, vtp, vx); \
    } \
    return true; \
  }

#define BINOP(xxx) \
  if (isa<xxx>(def)) { \
    if (codegen) \
      vexp = xxx::create(rewriter, loc, vx, vy); \
    return true; \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = vector::BroadcastOp::create(rewriter, loc, vtp, arg);
        Value incr = vector::StepOp::create(rewriter, loc, vtp);
        vexp = arith::AddIOp::create(rewriter, loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP
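
// Illustrative vector form of the innermost index expression a[i] = i handled
// in vectorizeExpr above, assuming a fixed vector<8xindex> type:
//
//   %bi = vector.broadcast %i : index to vector<8xindex>
//   %st = vector.step : vector<8xindex>
//   %iv = arith.addi %bi, %st : vector<8xindex>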

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) can be generated
  // when a custom reduce is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
      step = arith::MulIOp::create(rewriter, loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew =
          scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
                             forOp.getUpperBound(), step, vinit,
                             /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = arith::SelectOp::create(rewriter, loc, vmask, vrhs, vpass);
        scf::YieldOp::create(rewriter, loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = vector::ReductionOp::create(rewriter, loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}
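
// Illustrative before/after of a vectorized sum reduction, assuming a fixed
// vector<8xf64> type (mask operands and some types abbreviated):
//
//   before: %s = scf.for %i = %lo to %hi step %c1 iter_args(%a = %r) -> f64 {
//             %t = memref.load %b[%i] : memref<?xf64>
//             %u = arith.addf %a, %t : f64
//             scf.yield %u : f64
//           }
//
//   after:  %v = scf.for %i = %lo to %hi step %c8
//               iter_args(%va = %vinit) -> vector<8xf64> {
//             %m = vector.create_mask ... : vector<8xi1>
//             %t = vector.maskedload %b[%i], %m, %pass : ...
//             %u = arith.addf %va, %t : vector<8xf64>
//             %s = arith.select %m, %u, %va : vector<8xi1>, vector<8xf64>
//             scf.yield %s : vector<8xf64>
//           }
//           %r2 = vector.reduction <add>, %v : vector<8xf64> into f64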

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop that was generated by the
    // sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
                                     Value inp) {
  if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
    if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
      if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
        rewriter.replaceOp(op, redOp.getVector());
        return success();
      }
    }
  }
  return failure();
}

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)         v = for { }
///   u = broadcast(s) -> for (v) { }
///   for (u) { }
struct ReducChainBroadcastRewriter
    : public OpRewritePattern<vector::BroadcastOp> {
public:
  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    return cleanReducChain(rewriter, op, op.getSource());
  }
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)      v = for { }
///   u = insert(s) -> for (v) { }
///   for (u) { }
struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
public:
  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp op,
                                PatternRewriter &rewriter) const override {
    return cleanReducChain(rewriter, op, op.getValueToStore());
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  vector::populateVectorStepLoweringPatterns(patterns);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
      patterns.getContext());
}
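
// Illustrative usage (a sketch, not part of this file): a pass would typically
// wire these rules into the greedy pattern driver roughly as follows, with the
// flags taken from its own pass options:
//
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/true);
//   (void)applyPatternsGreedily(op, std::move(patterns));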