MLIR  20.0.0git
SparseVectorization.cpp
Go to the documentation of this file.
1 //===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts loops generated by the sparsifier into a form that
10 // can exploit SIMD instructions of the target architecture. Note that this pass
11 // ensures the sparsifier can generate efficient SIMD (including ArmSVE
12 // support) with proper separation of concerns as far as sparsification and
13 // vectorization is concerned. However, this pass is not the final abstraction
14 // level we want, and not the general vectorizer we want either. It forms a good
15 // stepping stone for incremental future improvements though.
16 //
17 //===----------------------------------------------------------------------===//
18 
#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Matchers.h"
32 
33 using namespace mlir;
34 using namespace mlir::sparse_tensor;
35 
36 namespace {
37 
/// SIMD characteristics of the compilation target.
struct VL {
  /// Number of data elements packed into one vector; for example,
  /// vector<16xf32> has a vector length of 16.
  unsigned vectorLength;
  /// When set, scalable (vector-length agnostic) vectors are produced, as
  /// used by ArmSVE.
  bool enableVLAVectorization;
  /// When set, gather/scatter operations use 32-bit indices for efficiency.
  bool enableSIMDIndex32;
};
47 
48 /// Helper test for invariant value (defined outside given block).
49 static bool isInvariantValue(Value val, Block *block) {
50  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
51 }
52 
53 /// Helper test for invariant argument (defined outside given block).
54 static bool isInvariantArg(BlockArgument arg, Block *block) {
55  return arg.getOwner() != block;
56 }
57 
58 /// Constructs vector type for element type.
59 static VectorType vectorType(VL vl, Type etp) {
60  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
61 }
62 
63 /// Constructs vector type from a memref value.
64 static VectorType vectorType(VL vl, Value mem) {
65  return vectorType(vl, getMemRefType(mem).getElementType());
66 }
67 
68 /// Constructs vector iteration mask.
69 static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
70  Value iv, Value lo, Value hi, Value step) {
71  VectorType mtp = vectorType(vl, rewriter.getI1Type());
72  // Special case if the vector length evenly divides the trip count (for
73  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
74  // so that all subsequent masked memory operations are immediately folded
75  // into unconditional memory operations.
76  IntegerAttr loInt, hiInt, stepInt;
77  if (matchPattern(lo, m_Constant(&loInt)) &&
78  matchPattern(hi, m_Constant(&hiInt)) &&
79  matchPattern(step, m_Constant(&stepInt))) {
80  if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
81  Value trueVal = constantI1(rewriter, loc, true);
82  return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
83  }
84  }
85  // Otherwise, generate a vector mask that avoids overrunning the upperbound
86  // during vector execution. Here we rely on subsequent loop optimizations to
87  // avoid executing the mask in all iterations, for example, by splitting the
88  // loop into an unconditional vector loop and a scalar cleanup loop.
89  auto min = AffineMap::get(
90  /*dimCount=*/2, /*symbolCount=*/1,
91  {rewriter.getAffineSymbolExpr(0),
92  rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
93  rewriter.getContext());
94  Value end = rewriter.createOrFold<affine::AffineMinOp>(
95  loc, min, ValueRange{hi, iv, step});
96  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
97 }
98 
99 /// Generates a vectorized invariant. Here we rely on subsequent loop
100 /// optimizations to hoist the invariant broadcast out of the vector loop.
101 static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
102  Value val) {
103  VectorType vtp = vectorType(vl, val.getType());
104  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
105 }
106 
107 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
108 /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
109 /// that the sparsifier can only generate indirect loads in
110 /// the last index, i.e. back().
111 static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
112  Value mem, ArrayRef<Value> idxs, Value vmask) {
113  VectorType vtp = vectorType(vl, mem);
114  Value pass = constantZero(rewriter, loc, vtp);
115  if (llvm::isa<VectorType>(idxs.back().getType())) {
116  SmallVector<Value> scalarArgs(idxs);
117  Value indexVec = idxs.back();
118  scalarArgs.back() = constantIndex(rewriter, loc, 0);
119  return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
120  indexVec, vmask, pass);
121  }
122  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
123  pass);
124 }
125 
126 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
127 /// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
128 /// that the sparsifier can only generate indirect stores in
129 /// the last index, i.e. back().
130 static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
131  ArrayRef<Value> idxs, Value vmask, Value rhs) {
132  if (llvm::isa<VectorType>(idxs.back().getType())) {
133  SmallVector<Value> scalarArgs(idxs);
134  Value indexVec = idxs.back();
135  scalarArgs.back() = constantIndex(rewriter, loc, 0);
136  rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
137  rhs);
138  return;
139  }
140  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
141 }
142 
143 /// Detects a vectorizable reduction operations and returns the
144 /// combining kind of reduction on success in `kind`.
145 static bool isVectorizableReduction(Value red, Value iter,
146  vector::CombiningKind &kind) {
147  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
148  kind = vector::CombiningKind::ADD;
149  return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
150  }
151  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
152  kind = vector::CombiningKind::ADD;
153  return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
154  }
155  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
156  kind = vector::CombiningKind::ADD;
157  return subf->getOperand(0) == iter;
158  }
159  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
160  kind = vector::CombiningKind::ADD;
161  return subi->getOperand(0) == iter;
162  }
163  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
164  kind = vector::CombiningKind::MUL;
165  return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
166  }
167  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
168  kind = vector::CombiningKind::MUL;
169  return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
170  }
171  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
172  kind = vector::CombiningKind::AND;
173  return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
174  }
175  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
176  kind = vector::CombiningKind::OR;
177  return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
178  }
179  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
180  kind = vector::CombiningKind::XOR;
181  return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
182  }
183  return false;
184 }
185 
186 /// Generates an initial value for a vector reduction, following the scheme
187 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
188 /// initial scalar value is correctly embedded in the vector reduction value,
189 /// and a straightforward horizontal reduction will complete the operation.
190 /// Value 'r' denotes the initial value of the reduction outside the loop.
191 static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
192  Value red, Value iter, Value r,
193  VectorType vtp) {
194  vector::CombiningKind kind;
195  if (!isVectorizableReduction(red, iter, kind))
196  llvm_unreachable("unknown reduction");
197  switch (kind) {
198  case vector::CombiningKind::ADD:
199  case vector::CombiningKind::XOR:
200  // Initialize reduction vector to: | 0 | .. | 0 | r |
201  return rewriter.create<vector::InsertElementOp>(
202  loc, r, constantZero(rewriter, loc, vtp),
203  constantIndex(rewriter, loc, 0));
204  case vector::CombiningKind::MUL:
205  // Initialize reduction vector to: | 1 | .. | 1 | r |
206  return rewriter.create<vector::InsertElementOp>(
207  loc, r, constantOne(rewriter, loc, vtp),
208  constantIndex(rewriter, loc, 0));
209  case vector::CombiningKind::AND:
210  case vector::CombiningKind::OR:
211  // Initialize reduction vector to: | r | .. | r | r |
212  return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
213  default:
214  break;
215  }
216  llvm_unreachable("unknown reduction kind");
217 }
218 
219 /// This method is called twice to analyze and rewrite the given subscripts.
220 /// The first call (!codegen) does the analysis. Then, on success, the second
221 /// call (codegen) yields the proper vector form in the output parameter
222 /// vector 'idxs'. This mechanism ensures that analysis and rewriting code
223 /// stay in sync. Note that the analyis part is simple because the sparsifier
224 /// only generates relatively simple subscript expressions.
225 ///
226 /// See https://llvm.org/docs/GetElementPtr.html for some background on
227 /// the complications described below.
228 ///
229 /// We need to generate a position/coordinate load from the sparse storage
230 /// scheme. Narrower data types need to be zero extended before casting
231 /// the value into the `index` type used for looping and indexing.
232 ///
233 /// For the scalar case, subscripts simply zero extend narrower indices
234 /// into 64-bit values before casting to an index type without a performance
235 /// penalty. Indices that already are 64-bit, in theory, cannot express the
236 /// full range since the LLVM backend defines addressing in terms of an
237 /// unsigned pointer/signed index pair.
238 static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
239  VL vl, ValueRange subs, bool codegen,
240  Value vmask, SmallVectorImpl<Value> &idxs) {
241  unsigned d = 0;
242  unsigned dim = subs.size();
243  Block *block = &forOp.getRegion().front();
244  for (auto sub : subs) {
245  bool innermost = ++d == dim;
246  // Invariant subscripts in outer dimensions simply pass through.
247  // Note that we rely on LICM to hoist loads where all subscripts
248  // are invariant in the innermost loop.
249  // Example:
250  // a[inv][i] for inv
251  if (isInvariantValue(sub, block)) {
252  if (innermost)
253  return false;
254  if (codegen)
255  idxs.push_back(sub);
256  continue; // success so far
257  }
258  // Invariant block arguments (including outer loop indices) in outer
259  // dimensions simply pass through. Direct loop indices in the
260  // innermost loop simply pass through as well.
261  // Example:
262  // a[i][j] for both i and j
263  if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
264  if (isInvariantArg(arg, block) == innermost)
265  return false;
266  if (codegen)
267  idxs.push_back(sub);
268  continue; // success so far
269  }
270  // Look under the hood of casting.
271  auto cast = sub;
272  while (true) {
273  if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
274  cast = icast->getOperand(0);
275  else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
276  cast = ecast->getOperand(0);
277  else
278  break;
279  }
280  // Since the index vector is used in a subsequent gather/scatter
281  // operations, which effectively defines an unsigned pointer + signed
282  // index, we must zero extend the vector to an index width. For 8-bit
283  // and 16-bit values, an 32-bit index width suffices. For 32-bit values,
284  // zero extending the elements into 64-bit loses some performance since
285  // the 32-bit indexed gather/scatter is more efficient than the 64-bit
286  // index variant (if the negative 32-bit index space is unused, the
287  // enableSIMDIndex32 flag can preserve this performance). For 64-bit
288  // values, there is no good way to state that the indices are unsigned,
289  // which creates the potential of incorrect address calculations in the
290  // unlikely case we need such extremely large offsets.
291  // Example:
292  // a[ ind[i] ]
293  if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
294  if (!innermost)
295  return false;
296  if (codegen) {
297  SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
298  Location loc = forOp.getLoc();
299  Value vload =
300  genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
301  Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
302  if (!llvm::isa<IndexType>(etp)) {
303  if (etp.getIntOrFloatBitWidth() < 32)
304  vload = rewriter.create<arith::ExtUIOp>(
305  loc, vectorType(vl, rewriter.getI32Type()), vload);
306  else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
307  vload = rewriter.create<arith::ExtUIOp>(
308  loc, vectorType(vl, rewriter.getI64Type()), vload);
309  }
310  idxs.push_back(vload);
311  }
312  continue; // success so far
313  }
314  // Address calculation 'i = add inv, idx' (after LICM).
315  // Example:
316  // a[base + i]
317  if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
318  Value inv = load.getOperand(0);
319  Value idx = load.getOperand(1);
320  // Swap non-invariant.
321  if (!isInvariantValue(inv, block)) {
322  inv = idx;
323  idx = load.getOperand(0);
324  }
325  // Inspect.
326  if (isInvariantValue(inv, block)) {
327  if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
328  if (isInvariantArg(arg, block) || !innermost)
329  return false;
330  if (codegen)
331  idxs.push_back(
332  rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
333  continue; // success so far
334  }
335  }
336  }
337  return false;
338  }
339  return true;
340 }
341 
342 #define UNAOP(xxx) \
343  if (isa<xxx>(def)) { \
344  if (codegen) \
345  vexp = rewriter.create<xxx>(loc, vx); \
346  return true; \
347  }
348 
349 #define TYPEDUNAOP(xxx) \
350  if (auto x = dyn_cast<xxx>(def)) { \
351  if (codegen) { \
352  VectorType vtp = vectorType(vl, x.getType()); \
353  vexp = rewriter.create<xxx>(loc, vtp, vx); \
354  } \
355  return true; \
356  }
357 
358 #define BINOP(xxx) \
359  if (isa<xxx>(def)) { \
360  if (codegen) \
361  vexp = rewriter.create<xxx>(loc, vx, vy); \
362  return true; \
363  }
364 
365 /// This method is called twice to analyze and rewrite the given expression.
366 /// The first call (!codegen) does the analysis. Then, on success, the second
367 /// call (codegen) yields the proper vector form in the output parameter 'vexp'.
368 /// This mechanism ensures that analysis and rewriting code stay in sync. Note
369 /// that the analyis part is simple because the sparsifier only generates
370 /// relatively simple expressions inside the for-loops.
371 static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
372  Value exp, bool codegen, Value vmask, Value &vexp) {
373  Location loc = forOp.getLoc();
374  // Reject unsupported types.
375  if (!VectorType::isValidElementType(exp.getType()))
376  return false;
377  // A block argument is invariant/reduction/index.
378  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
379  if (arg == forOp.getInductionVar()) {
380  // We encountered a single, innermost index inside the computation,
381  // such as a[i] = i, which must convert to [i, i+1, ...].
382  if (codegen) {
383  VectorType vtp = vectorType(vl, arg.getType());
384  Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
385  Value incr = rewriter.create<vector::StepOp>(loc, vtp);
386  vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
387  }
388  return true;
389  }
390  // An invariant or reduction. In both cases, we treat this as an
391  // invariant value, and rely on later replacing and folding to
392  // construct a proper reduction chain for the latter case.
393  if (codegen)
394  vexp = genVectorInvariantValue(rewriter, vl, exp);
395  return true;
396  }
397  // Something defined outside the loop-body is invariant.
398  Operation *def = exp.getDefiningOp();
399  Block *block = &forOp.getRegion().front();
400  if (def->getBlock() != block) {
401  if (codegen)
402  vexp = genVectorInvariantValue(rewriter, vl, exp);
403  return true;
404  }
405  // Proper load operations. These are either values involved in the
406  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
407  // or coordinate values inside the computation that are now fetched from
408  // the sparse storage coordinates arrays, such as a[i] = i becomes
409  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
410  // and 'hi = lo + vl - 1'.
411  if (auto load = dyn_cast<memref::LoadOp>(def)) {
412  auto subs = load.getIndices();
413  SmallVector<Value> idxs;
414  if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
415  if (codegen)
416  vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
417  return true;
418  }
419  return false;
420  }
421  // Inside loop-body unary and binary operations. Note that it would be
422  // nicer if we could somehow test and build the operations in a more
423  // concise manner than just listing them all (although this way we know
424  // for certain that they can vectorize).
425  //
426  // TODO: avoid visiting CSEs multiple times
427  //
428  if (def->getNumOperands() == 1) {
429  Value vx;
430  if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
431  vx)) {
432  UNAOP(math::AbsFOp)
433  UNAOP(math::AbsIOp)
434  UNAOP(math::CeilOp)
435  UNAOP(math::FloorOp)
436  UNAOP(math::SqrtOp)
437  UNAOP(math::ExpM1Op)
438  UNAOP(math::Log1pOp)
439  UNAOP(math::SinOp)
440  UNAOP(math::TanhOp)
441  UNAOP(arith::NegFOp)
442  TYPEDUNAOP(arith::TruncFOp)
443  TYPEDUNAOP(arith::ExtFOp)
444  TYPEDUNAOP(arith::FPToSIOp)
445  TYPEDUNAOP(arith::FPToUIOp)
446  TYPEDUNAOP(arith::SIToFPOp)
447  TYPEDUNAOP(arith::UIToFPOp)
448  TYPEDUNAOP(arith::ExtSIOp)
449  TYPEDUNAOP(arith::ExtUIOp)
450  TYPEDUNAOP(arith::IndexCastOp)
451  TYPEDUNAOP(arith::TruncIOp)
452  TYPEDUNAOP(arith::BitcastOp)
453  // TODO: complex?
454  }
455  } else if (def->getNumOperands() == 2) {
456  Value vx, vy;
457  if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
458  vx) &&
459  vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
460  vy)) {
461  // We only accept shift-by-invariant (where the same shift factor applies
462  // to all packed elements). In the vector dialect, this is still
463  // represented with an expanded vector at the right-hand-side, however,
464  // so that we do not have to special case the code generation.
465  if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
466  isa<arith::ShRSIOp>(def)) {
467  Value shiftFactor = def->getOperand(1);
468  if (!isInvariantValue(shiftFactor, block))
469  return false;
470  }
471  // Generate code.
472  BINOP(arith::MulFOp)
473  BINOP(arith::MulIOp)
474  BINOP(arith::DivFOp)
475  BINOP(arith::DivSIOp)
476  BINOP(arith::DivUIOp)
477  BINOP(arith::AddFOp)
478  BINOP(arith::AddIOp)
479  BINOP(arith::SubFOp)
480  BINOP(arith::SubIOp)
481  BINOP(arith::AndIOp)
482  BINOP(arith::OrIOp)
483  BINOP(arith::XOrIOp)
484  BINOP(arith::ShLIOp)
485  BINOP(arith::ShRUIOp)
486  BINOP(arith::ShRSIOp)
487  // TODO: complex?
488  }
489  }
490  return false;
491 }
492 
493 #undef UNAOP
494 #undef TYPEDUNAOP
495 #undef BINOP
496 
497 /// This method is called twice to analyze and rewrite the given for-loop.
498 /// The first call (!codegen) does the analysis. Then, on success, the second
499 /// call (codegen) rewriters the IR into vector form. This mechanism ensures
500 /// that analysis and rewriting code stay in sync.
501 static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
502  bool codegen) {
503  Block &block = forOp.getRegion().front();
504  // For loops with single yield statement (as below) could be generated
505  // when custom reduce is used with unary operation.
506  // for (...)
507  // yield c_0
508  if (block.getOperations().size() <= 1)
509  return false;
510 
511  Location loc = forOp.getLoc();
512  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
513  auto &last = *++block.rbegin();
514  scf::ForOp forOpNew;
515 
516  // Perform initial set up during codegen (we know that the first analysis
517  // pass was successful). For reductions, we need to construct a completely
518  // new for-loop, since the incoming and outgoing reduction type
519  // changes into SIMD form. For stores, we can simply adjust the stride
520  // and insert in the existing for-loop. In both cases, we set up a vector
521  // mask for all operations which takes care of confining vectors to
522  // the original iteration space (later cleanup loops or other
523  // optimizations can take care of those).
524  Value vmask;
525  if (codegen) {
526  Value step = constantIndex(rewriter, loc, vl.vectorLength);
527  if (vl.enableVLAVectorization) {
528  Value vscale =
529  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
530  step = rewriter.create<arith::MulIOp>(loc, vscale, step);
531  }
532  if (!yield.getResults().empty()) {
533  Value init = forOp.getInitArgs()[0];
534  VectorType vtp = vectorType(vl, init.getType());
535  Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
536  forOp.getRegionIterArg(0), init, vtp);
537  forOpNew = rewriter.create<scf::ForOp>(
538  loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
539  forOpNew->setAttr(
541  forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
542  rewriter.setInsertionPointToStart(forOpNew.getBody());
543  } else {
544  rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
545  rewriter.setInsertionPoint(yield);
546  }
547  vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
548  forOp.getLowerBound(), forOp.getUpperBound(), step);
549  }
550 
551  // Sparse for-loops either are terminated by a non-empty yield operation
552  // (reduction loop) or otherwise by a store operation (pararallel loop).
553  if (!yield.getResults().empty()) {
554  // Analyze/vectorize reduction.
555  if (yield->getNumOperands() != 1)
556  return false;
557  Value red = yield->getOperand(0);
558  Value iter = forOp.getRegionIterArg(0);
559  vector::CombiningKind kind;
560  Value vrhs;
561  if (isVectorizableReduction(red, iter, kind) &&
562  vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
563  if (codegen) {
564  Value partial = forOpNew.getResult(0);
565  Value vpass = genVectorInvariantValue(rewriter, vl, iter);
566  Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
567  rewriter.create<scf::YieldOp>(loc, vred);
568  rewriter.setInsertionPointAfter(forOpNew);
569  Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
570  // Now do some relinking (last one is not completely type safe
571  // but all bad ones are removed right away). This also folds away
572  // nop broadcast operations.
573  rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
574  rewriter.replaceAllUsesWith(forOp.getInductionVar(),
575  forOpNew.getInductionVar());
576  rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
577  forOpNew.getRegionIterArg(0));
578  rewriter.eraseOp(forOp);
579  }
580  return true;
581  }
582  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
583  // Analyze/vectorize store operation.
584  auto subs = store.getIndices();
585  SmallVector<Value> idxs;
586  Value rhs = store.getValue();
587  Value vrhs;
588  if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
589  vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
590  if (codegen) {
591  genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
592  rewriter.eraseOp(store);
593  }
594  return true;
595  }
596  }
597 
598  assert(!codegen && "cannot call codegen when analysis failed");
599  return false;
600 }
601 
602 /// Basic for-loop vectorizer.
603 struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
604 public:
606 
607  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
608  bool enableVLAVectorization, bool enableSIMDIndex32)
609  : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
610  enableSIMDIndex32} {}
611 
612  LogicalResult matchAndRewrite(scf::ForOp op,
613  PatternRewriter &rewriter) const override {
614  // Check for single block, unit-stride for-loop that is generated by
615  // sparsifier, which means no data dependence analysis is required,
616  // and its loop-body is very restricted in form.
617  if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
619  return failure();
620  // Analyze (!codegen) and rewrite (codegen) loop-body.
621  if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
622  vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
623  return success();
624  return failure();
625  }
626 
627 private:
628  const VL vl;
629 };
630 
631 /// Reduction chain cleanup.
632 /// v = for { }
633 /// s = vsum(v) v = for { }
634 /// u = expand(s) -> for (v) { }
635 /// for (u) { }
636 template <typename VectorOp>
637 struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
638 public:
640 
641  LogicalResult matchAndRewrite(VectorOp op,
642  PatternRewriter &rewriter) const override {
643  Value inp = op.getSource();
644  if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
645  if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
646  if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
647  rewriter.replaceOp(op, redOp.getVector());
648  return success();
649  }
650  }
651  }
652  return failure();
653  }
654 };
655 
656 } // namespace
657 
658 //===----------------------------------------------------------------------===//
659 // Public method for populating vectorization rules.
660 //===----------------------------------------------------------------------===//
661 
662 /// Populates the given patterns list with vectorization rules.
664  unsigned vectorLength,
665  bool enableVLAVectorization,
666  bool enableSIMDIndex32) {
667  assert(vectorLength > 0);
669  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
670  enableVLAVectorization, enableSIMDIndex32);
671  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
672  ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
673 }
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
#define UNAOP(xxx)
#define BINOP(xxx)
#define TYPEDUNAOP(xxx)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
reverse_iterator rbegin()
Definition: Block.h:145
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:408
IntegerType getI64Type()
Definition: Builders.cpp:109
IntegerType getI32Type()
Definition: Builders.cpp:107
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
MLIRContext * getContext() const
Definition: Builders.h:55
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
unsigned getNumOperands()
Definition: Operation.h:341
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
Definition: Operation.h:577
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
Definition: LoopEmitter.h:243
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Definition: CodegenUtils.h:309
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Definition: CodegenUtils.h:320
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
Definition: CodegenUtils.h:356
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Definition: SparseTensor.h:168
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358