1//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// A pass that converts loops generated by the sparsifier into a form that
10// can exploit SIMD instructions of the target architecture. Note that this pass
11// ensures the sparsifier can generate efficient SIMD (including ArmSVE
12// support) with proper separation of concerns as far as sparsification and
13// vectorization are concerned. However, this pass is not the final abstraction
14// level we want, nor the general vectorizer we want. It forms a good
15// stepping stone for incremental future improvements, though.
16//
17//===----------------------------------------------------------------------===//
18
19#include "Utils/CodegenUtils.h"
20#include "Utils/LoopEmitter.h"
21
22#include "mlir/Dialect/Affine/IR/AffineOps.h"
23#include "mlir/Dialect/Arith/IR/Arith.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
28#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
29#include "mlir/Dialect/Vector/IR/VectorOps.h"
30#include "mlir/IR/Matchers.h"
31
32using namespace mlir;
33using namespace mlir::sparse_tensor;
34
35namespace {
36
37/// Target SIMD properties:
38/// vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
39/// enableVLAVectorization: enables scalable vectors (viz. ArmSVE)
40/// enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
41struct VL {
42 unsigned vectorLength;
43 bool enableVLAVectorization;
44 bool enableSIMDIndex32;
45};
46
47/// Helper test for invariant value (defined outside given block).
48static bool isInvariantValue(Value val, Block *block) {
49 return val.getDefiningOp() && val.getDefiningOp()->getBlock() != block;
50}
51
52/// Helper test for invariant argument (defined outside given block).
53static bool isInvariantArg(BlockArgument arg, Block *block) {
54 return arg.getOwner() != block;
55}
56
57/// Constructs vector type for element type.
58static VectorType vectorType(VL vl, Type etp) {
59 return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
60}
61
62/// Constructs vector type from a memref value.
63static VectorType vectorType(VL vl, Value mem) {
64 return vectorType(vl, getMemRefType(mem).getElementType());
65}
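// Editorial note (not part of the original source): as an illustration of the
// two helpers above, VL{16, /*VLA=*/false, ...} with an f32 element type
// yields vector<16xf32>, whereas VL{4, /*VLA=*/true, ...} yields the scalable
// type vector<[4]xf32> used for ArmSVE-style vectorization.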
66
67/// Constructs vector iteration mask.
68static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
69 Value iv, Value lo, Value hi, Value step) {
70 VectorType mtp = vectorType(vl, rewriter.getI1Type());
71 // Special case if the vector length evenly divides the trip count (for
72 // example, "for i = 0, 128, 16"). A constant all-true mask is generated
73 // so that all subsequent masked memory operations are immediately folded
74 // into unconditional memory operations.
75 IntegerAttr loInt, hiInt, stepInt;
76 if (matchPattern(lo, m_Constant(&loInt)) &&
77 matchPattern(hi, m_Constant(&hiInt)) &&
78 matchPattern(step, m_Constant(&stepInt))) {
79 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
80 Value trueVal = constantI1(rewriter, loc, true);
81 return vector::BroadcastOp::create(rewriter, loc, mtp, trueVal);
82 }
83 }
84 // Otherwise, generate a vector mask that avoids overrunning the upper bound
85 // during vector execution. Here we rely on subsequent loop optimizations to
86 // avoid executing the mask in all iterations, for example, by splitting the
87 // loop into an unconditional vector loop and a scalar cleanup loop.
88 auto min = AffineMap::get(
89 /*dimCount=*/2, /*symbolCount=*/1,
90 {rewriter.getAffineSymbolExpr(0),
91 rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
92 rewriter.getContext());
93 Value end = rewriter.createOrFold<affine::AffineMinOp>(
94 loc, min, ValueRange{hi, iv, step});
95 return vector::CreateMaskOp::create(rewriter, loc, mtp, end);
96}
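// Editorial sketch (not part of the original source): for a loop such as
// "for i = 0 to 128 step 16" the mask above folds into a broadcast of 'true';
// in the general case the generated IR looks roughly like
//   %end  = affine.min affine_map<(d0, d1)[s0] -> (s0, d0 - d1)>(%hi, %iv)[%step]
//   %mask = vector.create_mask %end : vector<16xi1>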
97
98/// Generates a vectorized invariant. Here we rely on subsequent loop
99/// optimizations to hoist the invariant broadcast out of the vector loop.
100static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
101 Value val) {
102 VectorType vtp = vectorType(vl, val.getType());
103 return vector::BroadcastOp::create(rewriter, val.getLoc(), vtp, val);
104}
105
106/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
107/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
108/// that the sparsifier can only generate indirect loads in
109/// the last index, i.e. back().
110static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
111 Value mem, ArrayRef<Value> idxs, Value vmask) {
112 VectorType vtp = vectorType(vl, mem);
113 Value pass = constantZero(rewriter, loc, vtp);
114 if (llvm::isa<VectorType>(idxs.back().getType())) {
115 SmallVector<Value> scalarArgs(idxs);
116 Value indexVec = idxs.back();
117 scalarArgs.back() = constantIndex(rewriter, loc, 0);
118 return vector::GatherOp::create(rewriter, loc, vtp, mem, scalarArgs,
119 indexVec, vmask, pass);
120 }
121 return vector::MaskedLoadOp::create(rewriter, loc, vtp, mem, idxs, vmask,
122 pass);
123}
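// Editorial sketch (not part of the original source): for a 1-D memref,
// a unit-stride access a[lo:hi] lowers to roughly
//   %0 = vector.maskedload %mem[%iv], %vmask, %pass
// while an indirect access a[ind[lo:hi]] lowers to a gather that keeps the
// index vector in the last position
//   %0 = vector.gather %mem[%c0] [%indexVec], %vmask, %pass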
124
125/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
126/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
127/// that the sparsifier can only generate indirect stores in
128/// the last index, i.e. back().
129static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
130 ArrayRef<Value> idxs, Value vmask, Value rhs) {
131 if (llvm::isa<VectorType>(idxs.back().getType())) {
132 SmallVector<Value> scalarArgs(idxs);
133 Value indexVec = idxs.back();
134 scalarArgs.back() = constantIndex(rewriter, loc, 0);
135 vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem,
136 scalarArgs, indexVec, vmask, rhs);
137 return;
138 }
139 vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs);
140}
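// Editorial counterpart of the load sketch above (not part of the original
// source):
//   a[lo:hi] = rhs      ->  vector.maskedstore %mem[%iv], %vmask, %rhs
//   a[ind[lo:hi]] = rhs ->  vector.scatter %mem[%c0] [%indexVec], %vmask, %rhs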
141
142/// Detects a vectorizable reduction operation and, on success, returns the
143/// combining kind of the reduction in `kind`.
144static bool isVectorizableReduction(Value red, Value iter,
145 vector::CombiningKind &kind) {
146 if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
147 kind = vector::CombiningKind::ADD;
148 return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
149 }
150 if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
151 kind = vector::CombiningKind::ADD;
152 return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
153 }
154 if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
155 kind = vector::CombiningKind::ADD;
156 return subf->getOperand(0) == iter;
157 }
158 if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
159 kind = vector::CombiningKind::ADD;
160 return subi->getOperand(0) == iter;
161 }
162 if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
163 kind = vector::CombiningKind::MUL;
164 return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
165 }
166 if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
167 kind = vector::CombiningKind::MUL;
168 return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
169 }
170 if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
171 kind = vector::CombiningKind::AND;
172 return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
173 }
174 if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
175 kind = vector::CombiningKind::OR;
176 return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
177 }
178 if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
179 kind = vector::CombiningKind::XOR;
180 return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
181 }
182 return false;
183}
184
185/// Generates an initial value for a vector reduction, following the scheme
186/// given in Chapter 5 of "The Software Vectorization Handbook", where the
187/// initial scalar value is correctly embedded in the vector reduction value,
188/// and a straightforward horizontal reduction will complete the operation.
189/// Value 'r' denotes the initial value of the reduction outside the loop.
190static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
191 Value red, Value iter, Value r,
192 VectorType vtp) {
193 vector::CombiningKind kind;
194 if (!isVectorizableReduction(red, iter, kind))
195 llvm_unreachable("unknown reduction");
196 switch (kind) {
197 case vector::CombiningKind::ADD:
198 case vector::CombiningKind::XOR:
199 // Initialize reduction vector to: | 0 | .. | 0 | r |
200 return vector::InsertOp::create(rewriter, loc, r,
201 constantZero(rewriter, loc, vtp),
202 constantIndex(rewriter, loc, 0));
203 case vector::CombiningKind::MUL:
204 // Initialize reduction vector to: | 1 | .. | 1 | r |
205 return vector::InsertOp::create(rewriter, loc, r,
206 constantOne(rewriter, loc, vtp),
207 constantIndex(rewriter, loc, 0));
208 case vector::CombiningKind::AND:
209 case vector::CombiningKind::OR:
210 // Initialize reduction vector to: | r | .. | r | r |
211 return vector::BroadcastOp::create(rewriter, loc, vtp, r);
212 default:
213 break;
214 }
215 llvm_unreachable("unknown reduction kind");
216}
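// Editorial illustration (not part of the original source): with a 4-lane
// vector and initial scalar value r, the generated reduction vectors are
//   add/xor : | 0 | 0 | 0 | r |
//   mul     : | 1 | 1 | 1 | r |
//   and/or  : | r | r | r | r |
// so that a final horizontal vector.reduction produces the expected result.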
217
218/// This method is called twice to analyze and rewrite the given subscripts.
219/// The first call (!codegen) does the analysis. Then, on success, the second
220/// call (codegen) yields the proper vector form in the output parameter
221/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
222/// stay in sync. Note that the analysis part is simple because the sparsifier
223/// only generates relatively simple subscript expressions.
224///
225/// See https://llvm.org/docs/GetElementPtr.html for some background on
226/// the complications described below.
227///
228/// We need to generate a position/coordinate load from the sparse storage
229/// scheme. Narrower data types need to be zero extended before casting
230/// the value into the `index` type used for looping and indexing.
231///
232/// For the scalar case, subscripts simply zero extend narrower indices
233/// into 64-bit values before casting to an index type without a performance
234/// penalty. Indices that are already 64-bit cannot, in theory, express the
235/// full range, since the LLVM backend defines addressing in terms of an
236/// unsigned pointer/signed index pair.
237static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
238 VL vl, ValueRange subs, bool codegen,
239 Value vmask, SmallVectorImpl<Value> &idxs) {
240 unsigned d = 0;
241 unsigned dim = subs.size();
242 Block *block = &forOp.getRegion().front();
243 for (auto sub : subs) {
244 bool innermost = ++d == dim;
245 // Invariant subscripts in outer dimensions simply pass through.
246 // Note that we rely on LICM to hoist loads where all subscripts
247 // are invariant in the innermost loop.
248 // Example:
249 // a[inv][i] for inv
250 if (isInvariantValue(sub, block)) {
251 if (innermost)
252 return false;
253 if (codegen)
254 idxs.push_back(sub);
255 continue; // success so far
256 }
257 // Invariant block arguments (including outer loop indices) in outer
258 // dimensions simply pass through. Direct loop indices in the
259 // innermost loop simply pass through as well.
260 // Example:
261 // a[i][j] for both i and j
262 if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
263 if (isInvariantArg(arg, block) == innermost)
264 return false;
265 if (codegen)
266 idxs.push_back(sub);
267 continue; // success so far
268 }
269 // Look under the hood of casting.
270 auto cast = sub;
271 while (true) {
272 if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
273 cast = icast->getOperand(0);
274 else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
275 cast = ecast->getOperand(0);
276 else
277 break;
278 }
279 // Since the index vector is used in subsequent gather/scatter operations,
280 // which effectively define an unsigned pointer + signed index pair, we
281 // must zero extend the vector to an index width. For 8-bit and 16-bit
282 // values, a 32-bit index width suffices. For 32-bit values,
283 // zero extending the elements into 64-bit loses some performance since
284 // the 32-bit indexed gather/scatter is more efficient than the 64-bit
285 // index variant (if the negative 32-bit index space is unused, the
286 // enableSIMDIndex32 flag can preserve this performance). For 64-bit
287 // values, there is no good way to state that the indices are unsigned,
288 // which creates the potential of incorrect address calculations in the
289 // unlikely case we need such extremely large offsets.
290 // Example:
291 // a[ ind[i] ]
292 if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
293 if (!innermost)
294 return false;
295 if (codegen) {
296 SmallVector<Value> idxs2(load.getIndices()); // no need to analyze
297 Location loc = forOp.getLoc();
298 Value vload =
299 genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
300 Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
301 if (!llvm::isa<IndexType>(etp)) {
302 if (etp.getIntOrFloatBitWidth() < 32)
303 vload = arith::ExtUIOp::create(
304 rewriter, loc, vectorType(vl, rewriter.getI32Type()), vload);
305 else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
306 vload = arith::ExtUIOp::create(
307 rewriter, loc, vectorType(vl, rewriter.getI64Type()), vload);
308 }
309 idxs.push_back(vload);
310 }
311 continue; // success so far
312 }
313 // Address calculation 'i = add inv, idx' (after LICM).
314 // Example:
315 // a[base + i]
316 if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
317 Value inv = load.getOperand(0);
318 Value idx = load.getOperand(1);
319 // Swap non-invariant.
320 if (!isInvariantValue(inv, block)) {
321 inv = idx;
322 idx = load.getOperand(0);
323 }
324 // Inspect.
325 if (isInvariantValue(inv, block)) {
326 if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
327 if (isInvariantArg(arg, block) || !innermost)
328 return false;
329 if (codegen)
330 idxs.push_back(
331 arith::AddIOp::create(rewriter, forOp.getLoc(), inv, idx));
332 continue; // success so far
333 }
334 }
335 }
336 return false;
337 }
338 return true;
339}
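// Editorial summary of accepted subscripts (not part of the original source),
// with 'i' the innermost induction variable:
//   a[inv][i]   : invariant 'inv' and loop index 'i' pass through unchanged
//   a[ind[i]]   : 'ind[i]' becomes a vector load of the coordinates, zero
//                 extended to an index width suitable for gather/scatter
//   a[base + i] : accepted when 'base' is invariant (address calculation)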
340
341#define UNAOP(xxx) \
342 if (isa<xxx>(def)) { \
343 if (codegen) \
344 vexp = xxx::create(rewriter, loc, vx); \
345 return true; \
346 }
347
348#define TYPEDUNAOP(xxx) \
349 if (auto x = dyn_cast<xxx>(def)) { \
350 if (codegen) { \
351 VectorType vtp = vectorType(vl, x.getType()); \
352 vexp = xxx::create(rewriter, loc, vtp, vx); \
353 } \
354 return true; \
355 }
356
357#define BINOP(xxx) \
358 if (isa<xxx>(def)) { \
359 if (codegen) \
360 vexp = xxx::create(rewriter, loc, vx, vy); \
361 return true; \
362 }
363
364/// This method is called twice to analyze and rewrite the given expression.
365/// The first call (!codegen) does the analysis. Then, on success, the second
366/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
367/// This mechanism ensures that analysis and rewriting code stay in sync. Note
368/// that the analysis part is simple because the sparsifier only generates
369/// relatively simple expressions inside the for-loops.
370static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
371 Value exp, bool codegen, Value vmask, Value &vexp) {
372 Location loc = forOp.getLoc();
373 // Reject unsupported types.
374 if (!VectorType::isValidElementType(exp.getType()))
375 return false;
376 // A block argument is invariant/reduction/index.
377 if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
378 if (arg == forOp.getInductionVar()) {
379 // We encountered a single, innermost index inside the computation,
380 // such as a[i] = i, which must convert to [i, i+1, ...].
381 if (codegen) {
382 VectorType vtp = vectorType(vl, arg.getType());
383 Value veci = vector::BroadcastOp::create(rewriter, loc, vtp, arg);
384 Value incr = vector::StepOp::create(rewriter, loc, vtp);
385 vexp = arith::AddIOp::create(rewriter, loc, veci, incr);
386 }
387 return true;
388 }
389 // An invariant or reduction. In both cases, we treat this as an
390 // invariant value, and rely on later replacing and folding to
391 // construct a proper reduction chain for the latter case.
392 if (codegen)
393 vexp = genVectorInvariantValue(rewriter, vl, exp);
394 return true;
395 }
396 // Something defined outside the loop-body is invariant.
397 Operation *def = exp.getDefiningOp();
398 Block *block = &forOp.getRegion().front();
399 if (def->getBlock() != block) {
400 if (codegen)
401 vexp = genVectorInvariantValue(rewriter, vl, exp);
402 return true;
403 }
404 // Proper load operations. These are either values involved in the
405 // actual computation, such as a[i] = b[i] becomes a[lo:hi] = b[lo:hi],
406 // or coordinate values inside the computation that are now fetched from
407 // the sparse storage coordinates arrays, such as a[i] = i becomes
408 // a[lo:hi] = ind[lo:hi], where 'lo' denotes the current index
409 // and 'hi = lo + vl - 1'.
410 if (auto load = dyn_cast<memref::LoadOp>(def)) {
411 auto subs = load.getIndices();
412 SmallVector<Value> idxs;
413 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
414 if (codegen)
415 vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
416 return true;
417 }
418 return false;
419 }
420 // Inside loop-body unary and binary operations. Note that it would be
421 // nicer if we could somehow test and build the operations in a more
422 // concise manner than just listing them all (although this way we know
423 // for certain that they can vectorize).
424 //
425 // TODO: avoid visiting CSEs multiple times
426 //
427 if (def->getNumOperands() == 1) {
428 Value vx;
429 if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
430 vx)) {
431 UNAOP(math::AbsFOp)
432 UNAOP(math::AbsIOp)
433 UNAOP(math::CeilOp)
434 UNAOP(math::FloorOp)
435 UNAOP(math::SqrtOp)
436 UNAOP(math::ExpM1Op)
437 UNAOP(math::Log1pOp)
438 UNAOP(math::SinOp)
439 UNAOP(math::TanhOp)
440 UNAOP(arith::NegFOp)
441 TYPEDUNAOP(arith::TruncFOp)
442 TYPEDUNAOP(arith::ExtFOp)
443 TYPEDUNAOP(arith::FPToSIOp)
444 TYPEDUNAOP(arith::FPToUIOp)
445 TYPEDUNAOP(arith::SIToFPOp)
446 TYPEDUNAOP(arith::UIToFPOp)
447 TYPEDUNAOP(arith::ExtSIOp)
448 TYPEDUNAOP(arith::ExtUIOp)
449 TYPEDUNAOP(arith::IndexCastOp)
450 TYPEDUNAOP(arith::TruncIOp)
451 TYPEDUNAOP(arith::BitcastOp)
452 // TODO: complex?
453 }
454 } else if (def->getNumOperands() == 2) {
455 Value vx, vy;
456 if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
457 vx) &&
458 vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
459 vy)) {
460 // We only accept shift-by-invariant (where the same shift factor applies
461 // to all packed elements). In the vector dialect, this is still
462 // represented with an expanded vector on the right-hand side, however,
463 // so that we do not have to special-case the code generation.
464 if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
465 isa<arith::ShRSIOp>(def)) {
466 Value shiftFactor = def->getOperand(1);
467 if (!isInvariantValue(shiftFactor, block))
468 return false;
469 }
470 // Generate code.
471 BINOP(arith::MulFOp)
472 BINOP(arith::MulIOp)
473 BINOP(arith::DivFOp)
474 BINOP(arith::DivSIOp)
475 BINOP(arith::DivUIOp)
476 BINOP(arith::AddFOp)
477 BINOP(arith::AddIOp)
478 BINOP(arith::SubFOp)
479 BINOP(arith::SubIOp)
480 BINOP(arith::AndIOp)
481 BINOP(arith::OrIOp)
482 BINOP(arith::XOrIOp)
483 BINOP(arith::ShLIOp)
484 BINOP(arith::ShRUIOp)
485 BINOP(arith::ShRSIOp)
486 // TODO: complex?
487 }
488 }
489 return false;
490}
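// Editorial sketch (not part of the original source): for a loop body that
// computes "a[i] + scale" with invariant 'scale', the recursive calls above
// yield roughly
//   %va = vector.maskedload %a[%iv], %vmask, %pass
//   %vs = vector.broadcast %scale : f32 to vector<16xf32>
//   %vr = arith.addf %va, %vs : vector<16xf32>
// while the enclosing store or reduction is handled by vectorizeStmt below.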
491
492#undef UNAOP
493#undef TYPEDUNAOP
494#undef BINOP
495
496/// This method is called twice to analyze and rewrite the given for-loop.
497/// The first call (!codegen) does the analysis. Then, on success, the second
498/// call (codegen) rewrites the IR into vector form. This mechanism ensures
499/// that analysis and rewriting code stay in sync.
500static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
501 bool codegen) {
502 Block &block = forOp.getRegion().front();
503 // For-loops with a single yield statement (as below) can be generated
504 // when a custom reduction is used with a unary operation.
505 // for (...)
506 // yield c_0
507 if (block.getOperations().size() <= 1)
508 return false;
509
510 Location loc = forOp.getLoc();
511 scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
512 auto &last = *++block.rbegin();
513 scf::ForOp forOpNew;
514
515 // Perform initial setup during codegen (we know that the first analysis
516 // pass was successful). For reductions, we need to construct a completely
517 // new for-loop, since the incoming and outgoing reduction type
518 // changes into SIMD form. For stores, we can simply adjust the stride
519 // and insert in the existing for-loop. In both cases, we set up a vector
520 // mask for all operations which takes care of confining vectors to
521 // the original iteration space (later cleanup loops or other
522 // optimizations can take care of those).
523 Value vmask;
524 if (codegen) {
525 Value step = constantIndex(rewriter, loc, vl.vectorLength);
526 if (vl.enableVLAVectorization) {
527 Value vscale =
528 vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
529 step = arith::MulIOp::create(rewriter, loc, vscale, step);
530 }
531 if (!yield.getResults().empty()) {
532 Value init = forOp.getInitArgs()[0];
533 VectorType vtp = vectorType(vl, init.getType());
534 Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
535 forOp.getRegionIterArg(0), init, vtp);
536 forOpNew =
537 scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
538 forOp.getUpperBound(), step, vinit,
539 /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
540 forOpNew->setAttr(
541 LoopEmitter::getLoopEmitterLoopAttrName(),
542 forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
543 rewriter.setInsertionPointToStart(forOpNew.getBody());
544 } else {
545 rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
546 rewriter.setInsertionPoint(yield);
547 }
548 vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
549 forOp.getLowerBound(), forOp.getUpperBound(), step);
550 }
551
552 // Sparse for-loops are either terminated by a non-empty yield operation
553 // (reduction loop) or otherwise by a store operation (parallel loop).
554 if (!yield.getResults().empty()) {
555 // Analyze/vectorize reduction.
556 if (yield->getNumOperands() != 1)
557 return false;
558 Value red = yield->getOperand(0);
559 Value iter = forOp.getRegionIterArg(0);
560 vector::CombiningKind kind;
561 Value vrhs;
562 if (isVectorizableReduction(red, iter, kind) &&
563 vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
564 if (codegen) {
565 Value partial = forOpNew.getResult(0);
566 Value vpass = genVectorInvariantValue(rewriter, vl, iter);
567 Value vred = arith::SelectOp::create(rewriter, loc, vmask, vrhs, vpass);
568 scf::YieldOp::create(rewriter, loc, vred);
569 rewriter.setInsertionPointAfter(forOpNew);
570 Value vres = vector::ReductionOp::create(rewriter, loc, kind, partial);
571 // Now do some relinking (last one is not completely type safe
572 // but all bad ones are removed right away). This also folds away
573 // nop broadcast operations.
574 rewriter.replaceAllUsesWith(forOp.getResult(0), vres);
575 rewriter.replaceAllUsesWith(forOp.getInductionVar(),
576 forOpNew.getInductionVar());
577 rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
578 forOpNew.getRegionIterArg(0));
579 rewriter.eraseOp(forOp);
580 }
581 return true;
582 }
583 } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
584 // Analyze/vectorize store operation.
585 auto subs = store.getIndices();
586 SmallVector<Value> idxs;
587 Value rhs = store.getValue();
588 Value vrhs;
589 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
590 vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
591 if (codegen) {
592 genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
593 rewriter.eraseOp(store);
594 }
595 return true;
596 }
597 }
598
599 assert(!codegen && "cannot call codegen when analysis failed");
600 return false;
601}
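// Editorial end-to-end sketch (not part of the original source; names and
// types are illustrative): a scalar reduction loop such as
//   %sum = scf.for %i = %lo to %hi step %c1 iter_args(%r = %init) -> (f32) {
//     %t = memref.load %b[%i] : memref<?xf32>
//     %a = arith.addf %r, %t : f32
//     scf.yield %a : f32
//   }
// is rewritten (for vl = 4) into a masked vector loop followed by a single
// horizontal reduction, with %zero an all-zero vector<4xf32>:
//   %vinit = vector.insert %init, %zero [0] : f32 into vector<4xf32>
//   %vfor = scf.for %i = %lo to %hi step %c4 iter_args(%vr = %vinit) ... {
//     %mask = vector.create_mask ... : vector<4xi1>
//     %vt = vector.maskedload %b[%i], %mask, %pass
//     %va = arith.addf %vr, %vt : vector<4xf32>
//     %sel = arith.select %mask, %va, %vr : vector<4xi1>, vector<4xf32>
//     scf.yield %sel : vector<4xf32>
//   }
//   %sum = vector.reduction <add>, %vfor : vector<4xf32> into f32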
602
603/// Basic for-loop vectorizer.
604struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
605public:
606 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
607
608 ForOpRewriter(MLIRContext *context, unsigned vectorLength,
609 bool enableVLAVectorization, bool enableSIMDIndex32)
610 : OpRewritePattern(context),
611 vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
612
613 LogicalResult matchAndRewrite(scf::ForOp op,
614 PatternRewriter &rewriter) const override {
615 // Check for a single-block, unit-stride for-loop that was generated by
616 // the sparsifier, which means no data-dependence analysis is required
617 // and its loop body is very restricted in form.
618 if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) ||
619 !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
620 return failure();
621 // Analyze (!codegen) and rewrite (codegen) loop-body.
622 if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
623 vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
624 return success();
625 return failure();
626 }
627
628private:
629 const VL vl;
630};
631
632static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
633 Value inp) {
634 if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
635 if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
636 if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
637 rewriter.replaceOp(op, redOp.getVector());
638 return success();
639 }
640 }
641 }
642 return failure();
643}
644
645/// Reduction chain cleanup.
646/// v = for { }
647/// s = vsum(v) v = for { }
648/// u = broadcast(s) -> for (v) { }
649/// for (u) { }
650struct ReducChainBroadcastRewriter
651 : public OpRewritePattern<vector::BroadcastOp> {
652public:
653 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
654
655 LogicalResult matchAndRewrite(vector::BroadcastOp op,
656 PatternRewriter &rewriter) const override {
657 return cleanReducChain(rewriter, op, op.getSource());
658 }
659};
660
661/// Reduction chain cleanup.
662/// v = for { }
663/// s = vsum(v) v = for { }
664/// u = insert(s) -> for (v) { }
665/// for (u) { }
666struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
667public:
668 using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
669
670 LogicalResult matchAndRewrite(vector::InsertOp op,
671 PatternRewriter &rewriter) const override {
672 return cleanReducChain(rewriter, op, op.getValueToStore());
673 }
674};
675} // namespace
676
677//===----------------------------------------------------------------------===//
678// Public method for populating vectorization rules.
679//===----------------------------------------------------------------------===//
680
681/// Populates the given patterns list with vectorization rules.
682void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
683 unsigned vectorLength,
684 bool enableVLAVectorization,
685 bool enableSIMDIndex32) {
686 assert(vectorLength > 0);
687 vector::populateVectorStepLoweringPatterns(patterns);
688 patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
689 enableVLAVectorization, enableSIMDIndex32);
690 patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
691 patterns.getContext());
692}
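// Editorial usage sketch (not part of the original source; the driver call and
// flag values are assumptions): a pass would typically apply these rules as
//   RewritePatternSet patterns(ctx);
//   populateSparseVectorizationPatterns(patterns, /*vectorLength=*/8,
//                                       /*enableVLAVectorization=*/false,
//                                       /*enableSIMDIndex32=*/false);
//   (void)applyPatternsGreedily(op, std::move(patterns));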