SparseVectorization.cpp
//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this pass
// ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with a proper separation of concerns between sparsification and
// vectorization. However, this pass is neither the final abstraction level we
// want nor the general vectorizer we want. It forms a good stepping stone for
// incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
///   enableVLAVectorization: enables scalable vectors (viz. ARMSve)
///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};
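
// Illustrative note (not from the original source): with vectorLength = 16
// and f32 elements, the vectorType() helpers below produce the fixed-size
// type vector<16xf32>; with enableVLAVectorization they produce the scalable
// type vector<[16]xf32>, whose runtime length is a hardware-determined
// multiple of 16 (e.g. on ArmSVE).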

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upperbound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
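
// Illustrative example (not from the original source): for non-constant
// bounds and vl = 16, the generated mask is along the lines of
//   %n = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
//   %m = vector.create_mask %n : vector<16xi1>
// i.e. min(step, hi - iv) lanes remain enabled, so a final partial iteration
// never touches memory beyond the original upper bound.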

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}
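
// Illustrative example (not from the original source): a unit-stride access
// a[lo:hi] becomes a masked load along the lines of
//   %v = vector.maskedload %a[%i], %vmask, %pass
//          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// whereas an indirect access a[ind[lo:hi]] becomes a gather through the
// already vectorized index operand
//   %v = vector.gather %a[%c0] [%idxs], %vmask, %pass
//          : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
//            into vector<16xf32>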

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and, on success, returns the
/// combining kind of the reduction in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}
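
// Illustrative note (not from the original source): subtraction is only
// accepted with the iteration value on the left, e.g. x = x - a[i], which
// accumulates negated terms and can still be finalized with an additive
// horizontal reduction; x = a[i] - x is not a simple reduction and is
// rejected above.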

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}
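
// Illustrative example (not from the original source): for a sum reduction
// with initial scalar r and vl = 4, the accumulator starts as | 0 | 0 | 0 | r |,
// each vector iteration adds four elements lane-wise, and the final horizontal
// vector.reduction <add> yields r plus the total sum, counting the seed value
// exactly once.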

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}
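
// Illustrative summary (not from the original source) of the innermost
// subscript forms accepted above:
//   a[i]        : the loop index passes through; a contiguous masked
//                 load/store is generated
//   a[ind[i]]   : the coordinate array is itself loaded as a vector, zero
//                 extended to a suitable index width, and used in a
//                 gather/scatter
//   a[base + i] : the invariant base plus the loop index is rebuilt as a
//                 scalar subscript for a contiguous masked access
// Anything else (e.g. an invariant innermost subscript) rejects the loop.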

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }
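
// Illustrative note (not from the original source): each macro expands into
// a type test plus the vector rewrite; for instance BINOP(arith::AddFOp)
// expands to
//   if (isa<arith::AddFOp>(def)) {
//     if (codegen)
//       vexp = rewriter.create<arith::AddFOp>(loc, vx, vy);
//     return true;
//   }
// so the scalar operation is simply re-created on the vectorized operands.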

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr = rewriter.create<vector::StepOp>(loc, vtp);
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Inside loop-body unary and binary operations. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}
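
// Illustrative example (not from the original source): with vl = 16, an
// expression such as b[i] * scale, with 'scale' defined outside the loop,
// becomes roughly
//   %vb = vector.maskedload %b[%i], %vmask, %pass : ... into vector<16xf32>
//   %vs = vector.broadcast %scale : f32 to vector<16xf32>
//   %ve = arith.mulf %vb, %vs : vector<16xf32>
// where later loop optimizations are expected to hoist the broadcast.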

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) can be generated
  // when a custom reduction is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}
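
// Illustrative before/after sketch (not from the original source) for a sum
// reduction with vl = 16 (types and mask generation elided):
//   before:  %s = scf.for %i = %lo to %hi step %c1 iter_args(%x = %r) {
//              %t = memref.load %b[%i]
//              %y = arith.addf %x, %t
//              scf.yield %y
//            }
//   after:   %v = scf.for %i = %lo to %hi step %c16 iter_args(%vx = %vinit) {
//              %vt = vector.maskedload %b[%i], %mask, %pass
//              %vy = arith.addf %vx, %vt
//              %vs = arith.select %mask, %vy, %vx
//              scf.yield %vs
//            }
//            %s = vector.reduction <add>, %v
// where %vinit embeds %r as in genVectorReducInit above.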

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for single block, unit-stride for-loop that is generated by
    // sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)          v = for { }
///   u = expand(s)   ->   for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
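
// Illustrative usage sketch (not from the original source): a driver pass
// would typically populate these patterns and apply them greedily, e.g.
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
// so that unit-stride loops produced by the sparsifier are rewritten into
// masked vector form.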