//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this
// pass lets the sparsifier generate efficient SIMD code (including ArmSVE
// support) while keeping sparsification and vectorization properly separated
// concerns. However, this pass is neither the final abstraction level we
// want nor the general vectorizer we want. It does form a good stepping
// stone for incremental future improvements.
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
/// vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
/// enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
/// enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}
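
// For illustration (not from the pass itself): with vectorLength = 16 and
// f32 elements, vectorType returns vector<16xf32> for fixed-size
// vectorization, and vector<[16]xf32> (a scalable vector) when
// enableVLAVectorization is set.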

/// Constructs vector iteration mask.
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}
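
// As a sketch of the non-constant case (value names invented), for vl = 16
// the generated mask computation amounts to:
//   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
//   %mask = vector.create_mask %end : vector<16xi1>
// i.e. the mask enables min(step, hi - iv) lanes, so the final iteration
// cannot read or write past the upper bound.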

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}
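
// Roughly (illustrative IR, assuming vl = 16 and f32 data), the dense form
// a[lo:hi] becomes a masked load and the indirect form a[ind[lo:hi]] a gather:
//   %v = vector.maskedload %a[%iv], %mask, %pass
//          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
//   %v = vector.gather %a[%c0][%ind], %mask, %pass
//          : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
//            into vector<16xf32>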

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}
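
// Similarly (illustrative IR), the indirect store a[ind[lo:hi]] = rhs lowers
// to a scatter, and the dense store a[lo:hi] = rhs to a masked store:
//   vector.scatter %a[%c0][%ind], %mask, %vrhs
//     : memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
//   vector.maskedstore %a[%iv], %mask, %vrhs
//     : memref<?xf32>, vector<16xi1>, vector<16xf32>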

/// Detects a vectorizable reduction operation and returns the
/// combining kind of reduction on success in `kind`.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    // Subtraction only vectorizes as an ADD reduction when the iteration
    // value is the minuend (iter - x); x - iter is not a running sum.
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}
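
// A worked example (not from the pass itself): for an ADD reduction with
// initial value r over vl = 4 lanes, the vector accumulator starts as
// | 0 | .. | 0 | r |, each lane then accumulates its strided slice of the
// data, and the final horizontal vector.reduction <add> sums all lanes, so
// r is counted exactly once. AND/OR instead place r in every lane, which is
// safe because both operators are idempotent.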

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in subsequent gather/scatter
    // operations, which effectively define an unsigned pointer + signed
    // index pair, we must zero extend the vector to an index width. For
    // 8-bit and 16-bit values, a 32-bit index width suffices. For 32-bit
    // values, zero extending the elements into 64-bit loses some performance
    // since the 32-bit indexed gather/scatter is more efficient than the
    // 64-bit index variant (if the negative 32-bit index space is unused,
    // the enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential for incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}
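
// As an illustration of the indirect case (IR names invented): vectorizing
// the subscript in a[ind[i]] with vl = 16 and an i16 coordinate array first
// loads ind[lo:hi] and then zero extends it to an index width suitable for
// gather/scatter addressing:
//   %iv0 = vector.maskedload %ind[%i], %mask, %pass
//            : memref<?xi16>, vector<16xi1>, vector<16xi16> into vector<16xi16>
//   %iv1 = arith.extui %iv0 : vector<16xi16> to vector<16xi32>
// where %iv1 then drives the vector.gather or vector.scatter on a.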

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }
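
// For reference, BINOP(arith::AddFOp) expands to the test-and-build pattern
// used for every supported binary operation: match the defining op during
// analysis, and rebuild it over the vectorized operands during codegen:
//   if (isa<arith::AddFOp>(def)) {
//     if (codegen)
//       vexp = rewriter.create<arith::AddFOp>(loc, vx, vy);
//     return true;
//   }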

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
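  // Illustration (constants shown for vl = 4): the induction variable i
  // becomes the lane-indexed vector [i, i+1, i+2, i+3] via
  //   %veci = vector.broadcast %i : index to vector<4xindex>
  //   %incr = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
  //   %vexp = arith.addi %veci, %incr : vector<4xindex>
  // while the scalable path obtains %incr from an LLVM step-vector op.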
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Unary and binary operations inside the loop-body. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand-side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) can be generated
  // when a custom reduction is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
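    // Illustration: with VLA vectorization the step becomes a multiple of
    // the runtime vector length, e.g. for vl = 4:
    //   %vscale = vector.vscale
    //   %step   = arith.muli %vscale, %c4 : index
    // so one vector iteration covers vscale * 4 scalar iterations.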
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}
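
// End-to-end sketch (illustrative, with invented value names): a sparsifier
// reduction loop such as
//   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%r = %init) -> (f32) {
//     %t = memref.load %b[%i] : memref<?xf32>
//     %a = arith.addf %r, %t : f32
//     scf.yield %a : f32
//   }
// becomes, for vl = 16, a loop with step 16 over a vector<16xf32> accumulator
// that combines masked loads with an arith.select (so inactive lanes pass the
// accumulator through unchanged), followed by a single horizontal
// vector.reduction <add> after the loop.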

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for a single-block, unit-stride for-loop generated by the
    // sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)               v = for { }
///   u = expand(s)      ->     for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
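
// For reference, a driver pass might apply these patterns roughly as follows
// (a minimal sketch, not part of this file; assumes the greedy rewrite driver
// from mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(&getContext());
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));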