//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts loops generated by the sparsifier into a form that
// can exploit SIMD instructions of the target architecture. Note that this
// pass ensures the sparsifier can generate efficient SIMD (including ArmSVE
// support) with a proper separation of concerns as far as sparsification
// and vectorization are concerned. However, this pass is neither the final
// abstraction level we want, nor the general vectorizer we want. It forms a
// good stepping stone for incremental future improvements, though.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

/// Target SIMD properties:
/// vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
/// enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
/// enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};

/// Helper test for invariant value (defined outside given block).
static bool isInvariantValue(Value val, Block *block) {
  return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
}

/// Helper test for invariant argument (defined outside given block).
static bool isInvariantArg(BlockArgument arg, Block *block) {
  return arg.getOwner() != block;
}

/// Constructs vector type for element type.
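/// For example, with vl.vectorLength = 4, an f32 element type yields the
/// fixed type 'vector<4xf32>', or the scalable type 'vector<[4]xf32>' when
/// vl.enableVLAVectorization is set.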
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}

/// Constructs vector type from a memref value.
static VectorType vectorType(VL vl, Value mem) {
  return vectorType(vl, getMemRefType(mem).getElementType());
}

/// Constructs vector iteration mask.
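/// In the general case, the emitted IR computes the number of valid lanes
/// as min(step, hi - iv) and materializes the mask from it, roughly as
/// follows (a sketch for vl = 16; value names are illustrative):
///   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
///   %mask = vector.create_mask %end : vector<16xi1>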
static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
                           Value iv, Value lo, Value hi, Value step) {
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  // Special case if the vector length evenly divides the trip count (for
  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
  // so that all subsequent masked memory operations are immediately folded
  // into unconditional memory operations.
  IntegerAttr loInt, hiInt, stepInt;
  if (matchPattern(lo, m_Constant(&loInt)) &&
      matchPattern(hi, m_Constant(&hiInt)) &&
      matchPattern(step, m_Constant(&stepInt))) {
    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
      Value trueVal = constantI1(rewriter, loc, true);
      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
    }
  }
  // Otherwise, generate a vector mask that avoids overrunning the upper bound
  // during vector execution. Here we rely on subsequent loop optimizations to
  // avoid executing the mask in all iterations, for example, by splitting the
  // loop into an unconditional vector loop and a scalar cleanup loop.
  auto min = AffineMap::get(
      /*dimCount=*/2, /*symbolCount=*/1,
      {rewriter.getAffineSymbolExpr(0),
       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
      rewriter.getContext());
  Value end = rewriter.createOrFold<affine::AffineMinOp>(
      loc, min, ValueRange{hi, iv, step});
  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
}

/// Generates a vectorized invariant. Here we rely on subsequent loop
/// optimizations to hoist the invariant broadcast out of the vector loop.
static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
                                     Value val) {
  VectorType vtp = vectorType(vl, val.getType());
  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
}

/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect loads in
/// the last index, i.e. back().
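/// The two forms lower roughly as follows (a sketch for vl = 16 and f32
/// elements; value names are illustrative):
///   %1 = vector.maskedload %a[%iv], %vmask, %pass
///          : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
///   %1 = vector.gather %a[%c0] [%indexVec], %vmask, %pass
///          : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>
///            into vector<16xf32>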
static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                           Value mem, ArrayRef<Value> idxs, Value vmask) {
  VectorType vtp = vectorType(vl, mem);
  Value pass = constantZero(rewriter, loc, vtp);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    return rewriter.create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
                                             indexVec, vmask, pass);
  }
  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
                                               pass);
}

/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
/// that the sparsifier can only generate indirect stores in
/// the last index, i.e. back().
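/// Mirroring the load case above, the two forms lower roughly as follows
/// (a sketch for vl = 16 and f32 elements; value names are illustrative):
///   vector.maskedstore %a[%iv], %vmask, %rhs
///     : memref<?xf32>, vector<16xi1>, vector<16xf32>
///   vector.scatter %a[%c0] [%indexVec], %vmask, %rhs
///     : memref<?xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32>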
static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
    Value indexVec = idxs.back();
    scalarArgs.back() = constantIndex(rewriter, loc, 0);
    rewriter.create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
                                       rhs);
    return;
  }
  rewriter.create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
}

/// Detects a vectorizable reduction operation and returns the combining
/// kind of the reduction on success in `kind`. Note that subtraction only
/// qualifies when the iteration value is the first operand (x = x - a[i]),
/// since it then acts as an additive reduction of negated values.
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
  return false;
}

/// Generates an initial value for a vector reduction, following the scheme
/// given in Chapter 5 of "The Software Vectorization Handbook", where the
/// initial scalar value is correctly embedded in the vector reduction value,
/// and a straightforward horizontal reduction will complete the operation.
/// Value 'r' denotes the initial value of the reduction outside the loop.
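/// For example (a sketch for a sum reduction with vl = 4), the reduction
/// vector is initialized to | 0 | 0 | 0 | r |, so that the horizontal
/// 'vector.reduction <add>' emitted after the loop folds 'r' into the
/// final result exactly once.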
static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
                                Value red, Value iter, Value r,
                                VectorType vtp) {
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantZero(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return rewriter.create<vector::InsertElementOp>(
        loc, r, constantOne(rewriter, loc, vtp),
        constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
  default:
    break;
  }
  llvm_unreachable("unknown reduction kind");
}

/// This method is called twice to analyze and rewrite the given subscripts.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter
/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
/// stay in sync. Note that the analysis part is simple because the sparsifier
/// only generates relatively simple subscript expressions.
///
/// See https://llvm.org/docs/GetElementPtr.html for some background on
/// the complications described below.
///
/// We need to generate a position/coordinate load from the sparse storage
/// scheme. Narrower data types need to be zero extended before casting
/// the value into the `index` type used for looping and indexing.
///
/// For the scalar case, subscripts simply zero extend narrower indices
/// into 64-bit values before casting to an index type without a performance
/// penalty. Indices that already are 64-bit, in theory, cannot express the
/// full range since the LLVM backend defines addressing in terms of an
/// unsigned pointer/signed index pair.
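/// For example, an indirect subscript a[ind[i]] over an i32 coordinate
/// array vectorizes roughly as follows (a sketch for vl = 16 with
/// enableSIMDIndex32 = false; value names are illustrative):
///   %c = vector.maskedload %ind[%i], %vmask, %p0 : ... into vector<16xi32>
///   %x = arith.extui %c : vector<16xi32> to vector<16xi64>
///   %v = vector.gather %a[%c0] [%x], %vmask, %p1 : ... into vector<16xf32>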
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned d = 0;
  unsigned dim = subs.size();
  Block *block = &forOp.getRegion().front();
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    // Invariant subscripts in outer dimensions simply pass through.
    // Note that we rely on LICM to hoist loads where all subscripts
    // are invariant in the innermost loop.
    // Example:
    //   a[inv][i] for inv
    if (isInvariantValue(sub, block)) {
      if (innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Invariant block arguments (including outer loop indices) in outer
    // dimensions simply pass through. Direct loop indices in the
    // innermost loop simply pass through as well.
    // Example:
    //   a[i][j] for both i and j
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
        return false;
      if (codegen)
        idxs.push_back(sub);
      continue; // success so far
    }
    // Look under the hood of casting.
    auto cast = sub;
    while (true) {
      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
        cast = icast->getOperand(0);
      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
        cast = ecast->getOperand(0);
      else
        break;
    }
    // Since the index vector is used in a subsequent gather/scatter
    // operation, which effectively defines an unsigned pointer + signed
    // index, we must zero extend the vector to an index width. For 8-bit
    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
    // zero extending the elements into 64-bit loses some performance since
    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
    // index variant (if the negative 32-bit index space is unused, the
    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
    // values, there is no good way to state that the indices are unsigned,
    // which creates the potential of incorrect address calculations in the
    // unlikely case we need such extremely large offsets.
    // Example:
    //   a[ ind[i] ]
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      if (!innermost)
        return false;
      if (codegen) {
        SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
        Location loc = forOp.getLoc();
        Value vload =
            genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
        Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
        if (!llvm::isa<IndexType>(etp)) {
          if (etp.getIntOrFloatBitWidth() < 32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI32Type()), vload);
          else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
            vload = rewriter.create<arith::ExtUIOp>(
                loc, vectorType(vl, rewriter.getI64Type()), vload);
        }
        idxs.push_back(vload);
      }
      continue; // success so far
    }
    // Address calculation 'i = add inv, idx' (after LICM).
    // Example:
    //   a[base + i]
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      // Swap non-invariant.
      if (!isInvariantValue(inv, block)) {
        inv = idx;
        idx = load.getOperand(0);
      }
      // Inspect.
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          if (codegen)
            idxs.push_back(
                rewriter.create<arith::AddIOp>(forOp.getLoc(), inv, idx));
          continue; // success so far
        }
      }
    }
    return false;
  }
  return true;
}

#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx);                                    \
    return true;                                                               \
  }

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    if (codegen) {                                                             \
      VectorType vtp = vectorType(vl, x.getType());                            \
      vexp = rewriter.create<xxx>(loc, vtp, vx);                               \
    }                                                                          \
    return true;                                                               \
  }

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    if (codegen)                                                               \
      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
    return true;                                                               \
  }

/// This method is called twice to analyze and rewrite the given expression.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
/// This mechanism ensures that analysis and rewriting code stay in sync. Note
/// that the analysis part is simple because the sparsifier only generates
/// relatively simple expressions inside the for-loops.
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  Location loc = forOp.getLoc();
  // Reject unsupported types.
  if (!VectorType::isValidElementType(exp.getType()))
    return false;
  // A block argument is invariant/reduction/index.
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // We encountered a single, innermost index inside the computation,
      // such as a[i] = i, which must convert to [i, i+1, ...].
      if (codegen) {
        VectorType vtp = vectorType(vl, arg.getType());
        Value veci = rewriter.create<vector::BroadcastOp>(loc, vtp, arg);
        Value incr;
        if (vl.enableVLAVectorization) {
          Type stepvty = vectorType(vl, rewriter.getI64Type());
          Value stepv = rewriter.create<LLVM::StepVectorOp>(loc, stepvty);
          incr = rewriter.create<arith::IndexCastOp>(loc, vtp, stepv);
        } else {
          SmallVector<APInt> integers;
          for (unsigned i = 0, l = vl.vectorLength; i < l; i++)
            integers.push_back(APInt(/*width=*/64, i));
          auto values = DenseElementsAttr::get(vtp, integers);
          incr = rewriter.create<arith::ConstantOp>(loc, vtp, values);
        }
        vexp = rewriter.create<arith::AddIOp>(loc, veci, incr);
      }
      return true;
    }
    // An invariant or reduction. In both cases, we treat this as an
    // invariant value, and rely on later replacing and folding to
    // construct a proper reduction chain for the latter case.
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Something defined outside the loop-body is invariant.
  Operation *def = exp.getDefiningOp();
  Block *block = &forOp.getRegion().front();
  if (def->getBlock() != block) {
    if (codegen)
      vexp = genVectorInvariantValue(rewriter, vl, exp);
    return true;
  }
  // Proper load operations. These are either values involved in the
  // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
  // or coordinate values inside the computation that are now fetched from
  // the sparse storage coordinates arrays, such as a[i] = i becomes
  // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
  // and 'hi = lo + vl - 1'.
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    SmallVector<Value> idxs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      if (codegen)
        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
      return true;
    }
    return false;
  }
  // Unary and binary operations inside the loop-body. Note that it would be
  // nicer if we could somehow test and build the operations in a more
  // concise manner than just listing them all (although this way we know
  // for certain that they can vectorize).
  //
  // TODO: avoid visiting CSEs multiple times
  //
  if (def->getNumOperands() == 1) {
    Value vx;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx)) {
      UNAOP(math::AbsFOp)
      UNAOP(math::AbsIOp)
      UNAOP(math::CeilOp)
      UNAOP(math::FloorOp)
      UNAOP(math::SqrtOp)
      UNAOP(math::ExpM1Op)
      UNAOP(math::Log1pOp)
      UNAOP(math::SinOp)
      UNAOP(math::TanhOp)
      UNAOP(arith::NegFOp)
      TYPEDUNAOP(arith::TruncFOp)
      TYPEDUNAOP(arith::ExtFOp)
      TYPEDUNAOP(arith::FPToSIOp)
      TYPEDUNAOP(arith::FPToUIOp)
      TYPEDUNAOP(arith::SIToFPOp)
      TYPEDUNAOP(arith::UIToFPOp)
      TYPEDUNAOP(arith::ExtSIOp)
      TYPEDUNAOP(arith::ExtUIOp)
      TYPEDUNAOP(arith::IndexCastOp)
      TYPEDUNAOP(arith::TruncIOp)
      TYPEDUNAOP(arith::BitcastOp)
      // TODO: complex?
    }
  } else if (def->getNumOperands() == 2) {
    Value vx, vy;
    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                      vx) &&
        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                      vy)) {
      // We only accept shift-by-invariant (where the same shift factor applies
      // to all packed elements). In the vector dialect, this is still
      // represented with an expanded vector at the right-hand side, however,
      // so that we do not have to special case the code generation.
      if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
          isa<arith::ShRSIOp>(def)) {
        Value shiftFactor = def->getOperand(1);
        if (!isInvariantValue(shiftFactor, block))
          return false;
      }
      // Generate code.
      BINOP(arith::MulFOp)
      BINOP(arith::MulIOp)
      BINOP(arith::DivFOp)
      BINOP(arith::DivSIOp)
      BINOP(arith::DivUIOp)
      BINOP(arith::AddFOp)
      BINOP(arith::AddIOp)
      BINOP(arith::SubFOp)
      BINOP(arith::SubIOp)
      BINOP(arith::AndIOp)
      BINOP(arith::OrIOp)
      BINOP(arith::XOrIOp)
      BINOP(arith::ShLIOp)
      BINOP(arith::ShRUIOp)
      BINOP(arith::ShRSIOp)
      // TODO: complex?
    }
  }
  return false;
}

#undef UNAOP
#undef TYPEDUNAOP
#undef BINOP

/// This method is called twice to analyze and rewrite the given for-loop.
/// The first call (!codegen) does the analysis. Then, on success, the second
/// call (codegen) rewrites the IR into vector form. This mechanism ensures
/// that analysis and rewriting code stay in sync.
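/// For example (a sketch for vl = 16; value names are illustrative), a
/// parallel copy loop such as
///   scf.for %i = %lo to %hi step %c1 {
///     %x = memref.load %b[%i] : memref<?xf32>
///     memref.store %x, %a[%i] : memref<?xf32>
///   }
/// is rewritten in place into a strided loop over masked vector operations:
///   scf.for %i = %lo to %hi step %c16 {
///     %m = vector.create_mask ... : vector<16xi1>
///     %v = vector.maskedload %b[%i], %m, %pass : ... into vector<16xf32>
///     vector.maskedstore %a[%i], %m, %v : ...
///   }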
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  Block &block = forOp.getRegion().front();
  // For-loops with a single yield statement (as below) can be generated
  // when a custom reduce is used with a unary operation.
  //   for (...)
  //     yield c_0
  if (block.getOperations().size() <= 1)
    return false;

  Location loc = forOp.getLoc();
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  scf::ForOp forOpNew;

  // Perform initial set up during codegen (we know that the first analysis
  // pass was successful). For reductions, we need to construct a completely
  // new for-loop, since the incoming and outgoing reduction type
  // changes into SIMD form. For stores, we can simply adjust the stride
  // and insert in the existing for-loop. In both cases, we set up a vector
  // mask for all operations which takes care of confining vectors to
  // the original iteration space (later cleanup loops or other
  // optimizations can take care of those).
  Value vmask;
  if (codegen) {
    Value step = constantIndex(rewriter, loc, vl.vectorLength);
    if (vl.enableVLAVectorization) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      step = rewriter.create<arith::MulIOp>(loc, vscale, step);
    }
    if (!yield.getResults().empty()) {
      Value init = forOp.getInitArgs()[0];
      VectorType vtp = vectorType(vl, init.getType());
      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                       forOp.getRegionIterArg(0), init, vtp);
      forOpNew = rewriter.create<scf::ForOp>(
          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
      forOpNew->setAttr(
          LoopEmitter::getLoopEmitterLoopAttrName(),
          forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
      rewriter.setInsertionPointToStart(forOpNew.getBody());
    } else {
      rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
      rewriter.setInsertionPoint(yield);
    }
    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                          forOp.getLowerBound(), forOp.getUpperBound(), step);
  }

  // Sparse for-loops either are terminated by a non-empty yield operation
  // (reduction loop) or otherwise by a store operation (parallel loop).
  if (!yield.getResults().empty()) {
    // Analyze/vectorize reduction.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
        Value partial = forOpNew.getResult(0);
        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
        rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
        rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                    forOpNew.getInductionVar());
        rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                    forOpNew.getRegionIterArg(0));
        rewriter.eraseOp(forOp);
      }
      return true;
    }
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Analyze/vectorize store operation.
    auto subs = store.getIndices();
    SmallVector<Value> idxs;
    Value rhs = store.getValue();
    Value vrhs;
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      if (codegen) {
        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
        rewriter.eraseOp(store);
      }
      return true;
    }
  }

  assert(!codegen && "cannot call codegen when analysis failed");
  return false;
}

/// Basic for-loop vectorizer.
struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
public:
  using OpRewritePattern<scf::ForOp>::OpRewritePattern;

  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
                                      enableSIMDIndex32} {}

  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Check for single block, unit-stride for-loop that is generated by
    // the sparsifier, which means no data dependence analysis is required,
    // and its loop-body is very restricted in form.
    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
        !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
      return failure();
    // Analyze (!codegen) and rewrite (codegen) loop-body.
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
      return success();
    return failure();
  }

private:
  const VL vl;
};

/// Reduction chain cleanup.
///   v = for { }
///   s = vsum(v)          v = for { }
///   u = expand(s)   ->   for (v) { }
///   for (u) { }
template <typename VectorOp>
struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
public:
  using OpRewritePattern<VectorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(VectorOp op,
                                PatternRewriter &rewriter) const override {
    Value inp = op.getSource();
    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
          rewriter.replaceOp(op, redOp.getVector());
          return success();
        }
      }
    }
    return failure();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating vectorization rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with vectorization rules.
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
}
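
// Example use (a sketch; the surrounding pass and the greedy rewrite driver
// call are assumptions, not shown in this file):
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));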