MLIR 23.0.0git
ParallelLoopFusion.cpp
Go to the documentation of this file.
1//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop fusion on parallel loops.
10//
11//===----------------------------------------------------------------------===//
12
// NOTE(review): the include list was partially lost in extraction; restored
// from the symbols used below (SCF/memref/vector/arith/index/affine ops,
// AliasAnalysis, DominanceInfo, OperationEquivalence, memref utils) —
// confirm against upstream.
#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <optional>
#include <tuple>
40
41namespace mlir {
42#define GEN_PASS_DEF_SCFPARALLELLOOPFUSION
43#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
44} // namespace mlir
45
46using namespace mlir;
47using namespace mlir::scf;
48
49/// Verify there are no nested ParallelOps.
50static bool hasNestedParallelOp(ParallelOp ploop) {
51 auto walkResult =
52 ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
53 return walkResult.wasInterrupted();
54}
55
56/// Verify equal iteration spaces.
57static bool equalIterationSpaces(ParallelOp firstPloop,
58 ParallelOp secondPloop) {
59 if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
60 return false;
61
62 auto matchOperands = [&](const OperandRange &lhs,
63 const OperandRange &rhs) -> bool {
64 // TODO: Extend this to support aliases and equal constants.
65 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
66 };
67 return matchOperands(firstPloop.getLowerBound(),
68 secondPloop.getLowerBound()) &&
69 matchOperands(firstPloop.getUpperBound(),
70 secondPloop.getUpperBound()) &&
71 matchOperands(firstPloop.getStep(), secondPloop.getStep());
72}
73
74/// Check if both operations are the same type of memory write op and
75/// write to the same memory location (same buffer and same indices).
77 if (!op1 || !op2 || op1->getName() != op2->getName())
78 return false;
79 if (op1 == op2)
80 return true;
81 // support only these memory-writing ops for now
82 if (!isa<memref::StoreOp, vector::TransferWriteOp, vector::StoreOp>(op1))
83 return false;
84 bool opsAreIdentical =
86 .Case([&](memref::StoreOp storeOp1) {
87 auto storeOp2 = cast<memref::StoreOp>(op2);
88 return (storeOp1.getMemRef() == storeOp2.getMemRef()) &&
89 (storeOp1.getIndices() == storeOp2.getIndices());
90 })
91 .Case([&](vector::TransferWriteOp writeOp1) {
92 auto writeOp2 = cast<vector::TransferWriteOp>(op2);
93 return (writeOp1.getBase() == writeOp2.getBase()) &&
94 (writeOp1.getIndices() == writeOp2.getIndices()) &&
95 (writeOp1.getMask() == writeOp2.getMask()) &&
96 (writeOp1.getValueToStore().getType() ==
97 writeOp2.getValueToStore().getType()) &&
98 (writeOp1.getInBounds() == writeOp2.getInBounds());
99 })
100 .Case([&](vector::StoreOp vecStoreOp1) {
101 auto vecStoreOp2 = cast<vector::StoreOp>(op2);
102 return (vecStoreOp1.getBase() == vecStoreOp2.getBase()) &&
103 (vecStoreOp1.getIndices() == vecStoreOp2.getIndices()) &&
104 (vecStoreOp1.getValueToStore().getType() ==
105 vecStoreOp2.getValueToStore().getType()) &&
106 (vecStoreOp1.getAlignment() == vecStoreOp2.getAlignment()) &&
107 (vecStoreOp1.getNontemporal() ==
108 vecStoreOp2.getNontemporal());
109 })
110 .Default([](Operation *) { return false; });
111 return opsAreIdentical;
112}
113
114/// Check if val1 (from the first parallel loop) and val2 (from the
115/// second) are equivalent, considering the mapping of induction variables from
116/// the first to the second parallel loop.
117static bool valsAreEquivalent(Value val1, Value val2,
118 const IRMapping &loopsIVsMap) {
119 if (val1 == val2 || loopsIVsMap.lookupOrDefault(val1) == val2 ||
120 loopsIVsMap.lookupOrDefault(val2) == val1)
121 return true;
122 Operation *val1DefOp = val1.getDefiningOp();
123 Operation *val2DefOp = val2.getDefiningOp();
124 if (!val1DefOp || !val2DefOp)
125 return false;
126 if (!isMemoryEffectFree(val1DefOp) || !isMemoryEffectFree(val2DefOp))
127 return false;
129 val1DefOp, val2DefOp,
130 [&](Value v1, Value v2) {
131 return success(loopsIVsMap.lookupOrDefault(v1) == v2 ||
132 loopsIVsMap.lookupOrDefault(v2) == v1);
133 },
134 /*markEquivalent=*/nullptr, OperationEquivalence::Flags::IgnoreLocations);
135}
136
137/// If the `expr` value is the result of an integer addition of `base` and a
138/// constant, return the constant.
139static std::optional<int64_t> getAddConstant(Value expr, Value base,
140 const IRMapping &loopsIVsMap) {
141 if (auto addOp = expr.getDefiningOp<arith::AddIOp>()) {
142 if (auto constOp = getConstantIntValue(addOp.getLhs());
143 constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
144 return constOp.value();
145 if (auto constOp = getConstantIntValue(addOp.getRhs());
146 constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
147 return constOp.value();
148 return std::nullopt;
149 }
150
151 if (auto addOp = expr.getDefiningOp<index::AddOp>()) {
152 if (auto constOp = getConstantIntValue(addOp.getLhs());
153 constOp && valsAreEquivalent(addOp.getRhs(), base, loopsIVsMap))
154 return constOp.value();
155 if (auto constOp = getConstantIntValue(addOp.getRhs());
156 constOp && valsAreEquivalent(addOp.getLhs(), base, loopsIVsMap))
157 return constOp.value();
158 return std::nullopt;
159 }
160
161 if (auto applyOp = expr.getDefiningOp<affine::AffineApplyOp>()) {
162 AffineMap map = applyOp.getAffineMap();
163 if (map.getNumResults() != 1 || map.getNumDims() != 1 ||
164 map.getNumSymbols() != 0)
165 return std::nullopt;
166 if (!valsAreEquivalent(applyOp.getOperand(0), base, loopsIVsMap))
167 return std::nullopt;
168 AffineExpr result = map.getResult(0);
169 auto bin = dyn_cast<AffineBinaryOpExpr>(result);
170 if (!bin || bin.getKind() != AffineExprKind::Add)
171 return std::nullopt;
172 auto lhsDim = dyn_cast<AffineDimExpr>(bin.getLHS());
173 auto rhsDim = dyn_cast<AffineDimExpr>(bin.getRHS());
174 auto lhsConst = dyn_cast<AffineConstantExpr>(bin.getLHS());
175 auto rhsConst = dyn_cast<AffineConstantExpr>(bin.getRHS());
176 if (lhsConst && rhsDim)
177 return lhsConst.getValue();
178 if (rhsConst && lhsDim)
179 return rhsConst.getValue();
180 }
181 return std::nullopt;
182}
183
184// Return true if the scalar load index may hit any element covered by a
185// vector.store/transfer_write along a single memref dimension. Supported cases:
186//
187// 1) Direct index match (with optional offset):
188// vector.transfer_write %v, %A[%i] : vector<4xf32>, memref<...>
189// %x = memref.load %A[%i] : memref<...>
190//
191// 2) Loop IV range intersects the write range:
192// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
193// scf.for %k = %c0 to %c4 step %c1 { %x = memref.load %A[%k] }
194//
195// 3) Constant index (or IV + constant) within the write range:
196// vector.transfer_write %v, %A[%c0] : vector<4xf32>, memref<...>
197// %x = memref.load %A[%c2] : memref<...>
198// %y = memref.load %A[%i + %c1] : memref<...>
199//
200// Args:
201// - loadIndex: index used by the scalar load for this dimension.
202// - offset: subview offset for the base memref dimension (if any).
203// - writeIndex: index used by the transfer_write for this dimension. Can be
204// null if the dim was dropped by a rank reducing subview, whose result is
205// written by the vector.write.
206// - extent: vector size along this dimension (number of elements written).
207// - loopsIVsMap: IV equivalence map between fused loops.
208static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset,
209 Value writeIndex, int64_t extent,
210 const IRMapping &loopsIVsMap) {
211 if (extent <= 0)
212 return false;
213
214 // Extract constant loop bounds for loop IVs (e.g. from scf.for).
215 auto getConstLoopBoundsForIV =
216 [](Value index) -> std::optional<std::tuple<int64_t, int64_t, int64_t>> {
217 auto blockArg = dyn_cast<BlockArgument>(index);
218 if (!blockArg)
219 return std::nullopt;
220 auto *parentOp = blockArg.getOwner()->getParentOp();
221 auto loopLike = dyn_cast<LoopLikeOpInterface>(parentOp);
222 if (!loopLike)
223 return std::nullopt;
224 auto ranges = getConstLoopBounds(loopLike);
225 if (ranges.empty())
226 return std::nullopt;
227
228 auto ivs = loopLike.getLoopInductionVars();
229 if (!ivs)
230 return std::nullopt;
231 auto it = llvm::find(*ivs, blockArg);
232 if (it == ivs->end())
233 return std::nullopt;
234 unsigned pos = std::distance(ivs->begin(), it);
235 if (pos >= ranges.size())
236 return std::nullopt;
237 auto [lb, ub, step] = ranges[pos];
238 return std::make_tuple(lb, ub, step);
239 };
240
241 std::optional<int64_t> offsetConst = getConstantIntValue(offset);
242 std::optional<int64_t> writeConst =
243 writeIndex ? getConstantIntValue(writeIndex) : std::optional<int64_t>(0);
244 if (!writeConst && writeIndex) {
245 // Treat single-iteration IVs as constants for matching.
246 if (auto bounds = getConstLoopBoundsForIV(writeIndex)) {
247 auto [lb, ub, step] = *bounds;
248 if (step > 0 && ub == lb + step)
249 writeConst = lb;
250 }
251 }
252
253 // Check whether a loop IV is fully contained in a constant write range.
254 auto loopIVWithinRange = [](int64_t lb, int64_t ub, int64_t step,
255 int64_t rangeStart, int64_t rangeExtent) -> bool {
256 if (rangeExtent <= 0 || step <= 0)
257 return false;
258 if (ub <= lb)
259 return false;
260 int64_t rangeEnd = rangeStart + rangeExtent;
261 return lb >= rangeStart && ub <= rangeEnd;
262 };
263
264 if (offsetConst && writeConst) {
265 // Constant start of the write range; check constant load or loop IV range.
266 int64_t start = *offsetConst + *writeConst;
267 if (auto loadConst = getConstantIntValue(loadIndex))
268 return (*loadConst >= start && *loadConst < start + extent);
269 if (auto bounds = getConstLoopBoundsForIV(loadIndex)) {
270 auto [lb, ub, step] = *bounds;
271 return loopIVWithinRange(lb, ub, step, start, extent);
272 }
273 }
274
275 if (writeIndex) {
276 // Direct IV match (or IV + constant) against the write index.
277 if (offsetConst && *offsetConst == 0 &&
278 valsAreEquivalent(loadIndex, writeIndex, loopsIVsMap))
279 return true;
280 if (auto addConst = getAddConstant(loadIndex, writeIndex, loopsIVsMap)) {
281 // Match load index of the form writeIndex + C within the write extent.
282 if (offsetConst) {
283 int64_t start = *offsetConst;
284 return (*addConst >= start && *addConst < start + extent);
285 }
286 }
287 return false;
288 }
289
290 if (auto offsetVal = dyn_cast<Value>(offset)) {
291 // Exact match when extent is 1 and the load hits the offset value.
292 if (extent == 1 && valsAreEquivalent(loadIndex, offsetVal, loopsIVsMap))
293 return true;
294 }
295
296 return false;
297}
298
299/// Return the base memref value used by the given memory op.
301 // TODO: use the common interface for memory ops once available.
303 .Case([&](memref::LoadOp load) { return load.getMemRef(); })
304 .Case([&](memref::StoreOp store) { return store.getMemRef(); })
305 .Case([&](vector::TransferReadOp read) { return read.getBase(); })
306 .Case([&](vector::TransferWriteOp write) { return write.getBase(); })
307 .Case([&](vector::LoadOp load) { return load.getBase(); })
308 .Case([&](vector::StoreOp store) { return store.getBase(); })
309 .Default([](Operation *) { return Value(); });
310}
311
312/// Recognize scalar memref.load of an element produced by a vector write
313/// (vector.transfer_write or vector.store, optionally through a rank-reducing
314/// unit-stride subview) of the same buffer. This covers the pattern where a
315/// vector write stores a full lane pack and a subsequent scalar load reads an
316/// element from that lane pack. EXAMPLE:
317/// vector.transfer_write %V, %arg[%x, %y, ..., 0] {in_bounds = [true]} :
318/// vector<4xf32>, memref<4xf32, strided<[1], offset: ?>>
319/// scf.for %iter = %c0 to %c4 step %c1 iter_args(...) -> (f32) {
320/// %0 = memref.load %arg[%x, %y, ..., %iter] : memref<1x128x16x4xf32>
321/// ...
322/// }
323///
324static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase,
325 ValueRange writeIndices, VectorType vecTy,
326 ArrayRef<int64_t> vectorDimForWriteDim,
327 const IRMapping &ivsMap) {
328 if (!vecTy)
329 return false;
330
331 Value base = writeBase;
332 // The write base if there is no subview, or the subview source otherwise.
333 MemrefValue baseMemref = nullptr;
335 llvm::SmallBitVector droppedDims;
336 bool hasSubview = false;
337 auto *ctx = loadOp.getContext();
338 if (auto subView = base.getDefiningOp<memref::SubViewOp>()) {
339 if (!subView.hasUnitStride())
340 return false;
341 baseMemref = cast<MemrefValue>(subView.getSource());
342 offsets = llvm::to_vector(subView.getMixedOffsets());
343 droppedDims = subView.getDroppedDims();
344 hasSubview = true;
345 } else {
346 baseMemref = dyn_cast<MemrefValue>(base);
347 if (!baseMemref)
348 return false;
349 }
350
351 auto loadIndices = loadOp.getIndices();
352 unsigned baseRank = baseMemref.getType().getRank();
353 if ((loadOp.getMemref() != baseMemref) || (loadIndices.size() != baseRank))
354 return false;
355
356 unsigned writeRank = writeIndices.size();
357 if ((!hasSubview && writeRank != baseRank) ||
358 (hasSubview && offsets.size() != baseRank) ||
359 (vectorDimForWriteDim.size() != writeRank))
360 return false;
361
362 auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
363 unsigned writeMemrefDim = 0;
364 for (unsigned baseDim : llvm::seq(baseRank)) {
365 bool wasDropped = (hasSubview && droppedDims.test(baseDim));
366 int64_t vectorDim = !wasDropped ? vectorDimForWriteDim[writeMemrefDim] : -1;
367 int64_t extent = 1;
368 if (vectorDim >= 0) {
369 int64_t dimSize = vecTy.getDimSize(vectorDim);
370 if (dimSize == ShapedType::kDynamic)
371 return false;
372 extent = dimSize;
373 }
374 Value writeIndex = !wasDropped ? writeIndices[writeMemrefDim] : Value();
375 OpFoldResult offset =
376 hasSubview ? offsets[baseDim] : OpFoldResult(zeroAttr);
377 if (!loadIndexWithinWriteRange(loadIndices[baseDim], offset, writeIndex,
378 extent, ivsMap))
379 return false;
380 if (!wasDropped)
381 ++writeMemrefDim;
382 }
383
384 return true;
385}
386
387/// Recognize scalar memref.load of an element produced by a
388/// vector.transfer_write
389static bool loadMatchesVectorWrite(memref::LoadOp loadOp,
390 vector::TransferWriteOp writeOp,
391 const IRMapping &ivsMap) {
392 auto vecTy = dyn_cast<VectorType>(writeOp.getVector().getType());
393 if (!vecTy)
394 return false;
395
396 unsigned writeRank = writeOp.getIndices().size();
397 AffineMap permutationMap = writeOp.getPermutationMap();
398 if (!permutationMap.isProjectedPermutation() ||
399 permutationMap.getNumResults() != vecTy.getRank() ||
400 permutationMap.getNumDims() != writeRank)
401 return false;
402
403 SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
404 for (unsigned vecDim = 0; vecDim < permutationMap.getNumResults(); ++vecDim) {
405 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(vecDim));
406 if (!dimExpr)
407 return false;
408 unsigned writeDim = dimExpr.getPosition();
409 if (writeDim >= writeRank || vectorDimForWriteDim[writeDim] != -1)
410 return false;
411 vectorDimForWriteDim[writeDim] = vecDim;
412 }
413
414 return isLoadOnWrittenVector(loadOp, writeOp.getBase(), writeOp.getIndices(),
415 vecTy, vectorDimForWriteDim, ivsMap);
416}
417
418/// Recognize scalar memref.load of an element produced by a vector.store
419static bool loadMatchesVectorStore(memref::LoadOp loadOp,
420 vector::StoreOp storeOp,
421 const IRMapping &ivsMap) {
422 auto vecTy = dyn_cast<VectorType>(storeOp.getValueToStore().getType());
423 if (!vecTy)
424 return false;
425
426 unsigned writeRank = storeOp.getIndices().size();
427 if (vecTy.getRank() > writeRank)
428 return false;
429
430 SmallVector<int64_t> vectorDimForWriteDim(writeRank, -1);
431 unsigned vecRank = vecTy.getRank();
432 for (unsigned i = 0; i < vecRank; ++i) {
433 unsigned writeDim = writeRank - vecRank + i;
434 vectorDimForWriteDim[writeDim] = i;
435 }
436
437 return isLoadOnWrittenVector(loadOp, storeOp.getBase(), storeOp.getIndices(),
438 vecTy, vectorDimForWriteDim, ivsMap);
439}
440
441/// Check if both operations access the same positions of the same
442/// buffer, but one of the two does it through a rank-reducing full subview of
443/// the buffer (the other's base). EXAMPLE:
444/// memref.store %a, %buf[%c0, %i, %j] : memref<1x2x2xf32>
445/// %alias = memref.subview %buf[0, 0, 0][1, 2, 2][1, 1, 1]: memref<1x2x2xf32>
446/// to memref<2x2xf32>
447/// %val = memref.load %alias[%i, %j] : memref<2x2xf32>
448template <typename OpTy1, typename OpTy2>
450 OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap,
451 OpBuilder &b) {
452 auto base1 = cast<MemrefValue>(getBaseMemref(op1));
453 auto base2 = cast<MemrefValue>(getBaseMemref(op2));
454 if (!base1 || !base2)
455 return false;
456
457 auto accessThroughTrivialSubviewIsSame =
458 [&b](memref::SubViewOp subView, ValueRange subViewAccess,
459 ValueRange sourceAccess, const IRMapping &ivsMap) -> bool {
460 SmallVector<Value> resolvedSubviewAccess;
461 LogicalResult resolved = resolveSourceIndicesRankReducingSubview(
462 subView.getLoc(), b, subView, subViewAccess, resolvedSubviewAccess);
463 if (failed(resolved) ||
464 (resolvedSubviewAccess.size() != sourceAccess.size()))
465 return false;
466 for (auto [dimIdx, resolvedIndex] :
467 llvm::enumerate(resolvedSubviewAccess)) {
468 if (!matchPattern(resolvedIndex, m_Zero()) &&
469 !valsAreEquivalent(resolvedIndex, sourceAccess[dimIdx], ivsMap))
470 return false;
471 }
472 return true;
473 };
474
475 // Case 1: op1 uses a subview of op2's base.
476 if (auto subView = base1.template getDefiningOp<memref::SubViewOp>();
477 subView &&
479 base2, cast<MemrefValue>(subView.getSource())) &&
480 accessThroughTrivialSubviewIsSame(subView, op1.getIndices(),
481 op2.getIndices(),
482 firstToSecondPloopIVsMap))
483 return true;
484
485 // Case 2: op2 uses a subview of op1's base.
486 if (auto subView = base2.template getDefiningOp<memref::SubViewOp>();
487 subView &&
489 base1, cast<MemrefValue>(subView.getSource())) &&
490 accessThroughTrivialSubviewIsSame(subView, op2.getIndices(),
491 op1.getIndices(),
492 firstToSecondPloopIVsMap))
493 return true;
494
495 return false;
496}
497
498/// Check if both memory read/write operations access the same indices
499/// (considering also the mapping of induction variables from the first to the
500/// second parallel loop).
501template <typename OpTy1, typename OpTy2>
502static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2,
503 const IRMapping &loopsIVsMap, OpBuilder &b) {
504 auto indices1 = op1.getIndices();
505 auto indices2 = op2.getIndices();
506 if (indices1.size() != indices2.size())
507 return opsAccessSameIndicesViaRankReducingSubview(op1, op2, loopsIVsMap, b);
508 for (auto [idx1, idx2] : llvm::zip(indices1, indices2)) {
509 if (!valsAreEquivalent(idx1, idx2, loopsIVsMap))
510 return false;
511 }
512 return true;
513}
514
515/// Check if the loadOp reads from the same memory location (same buffer,
516/// same indices and same properties) as written by the storeOp.
517static bool
519 const IRMapping &firstToSecondPloopIVsMap,
520 OpBuilder &b) {
521 if (!loadOp || !storeOp)
522 return false;
523 // Support only these memory-reading ops for now
524 if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp))
525 return false;
526 bool accessSameMemory =
528 .Case([&](memref::LoadOp memLoadOp) {
529 if (auto memStoreOp = dyn_cast<memref::StoreOp>(storeOp))
530 return opsAccessSameIndices(memLoadOp, memStoreOp,
531 firstToSecondPloopIVsMap, b);
532 if (auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp))
533 return loadMatchesVectorWrite(memLoadOp, vecWriteOp,
534 firstToSecondPloopIVsMap);
535 if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp))
536 return loadMatchesVectorStore(memLoadOp, vecStoreOp,
537 firstToSecondPloopIVsMap);
538 return false;
539 })
540 .Case([&](vector::TransferReadOp vecReadOp) {
541 auto vecWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
542 if (!vecWriteOp)
543 return false;
544 return opsAccessSameIndices(vecReadOp, vecWriteOp,
545 firstToSecondPloopIVsMap, b) &&
546 (vecReadOp.getMask() == vecWriteOp.getMask()) &&
547 (vecReadOp.getInBounds() == vecWriteOp.getInBounds());
548 })
549 .Case([&](vector::LoadOp vecLoadOp) {
550 auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
551 if (!vecStoreOp)
552 return false;
553 return opsAccessSameIndices(vecLoadOp, vecStoreOp,
554 firstToSecondPloopIVsMap, b) &&
555 (vecLoadOp.getAlignment() == vecStoreOp.getAlignment());
556 })
557 .Default([](Operation *) { return false; });
558 return accessSameMemory;
559}
560
563 .Case([&](memref::StoreOp storeOp) { return storeOp.getMemRef(); })
564 .Case([&](vector::TransferWriteOp writeOp) { return writeOp.getBase(); })
565 .Case([&](vector::StoreOp vecStoreOp) { return vecStoreOp.getBase(); })
566 .Default([](Operation *) { return Value(); });
567}
568
569/// To be called when `mayAlias(val1, val2)` is true. Check if the potential
570/// aliasing between the loadOp and storeOp can be resolved by analyzing their
571/// access patterns.
572static bool canResolveAlias(Operation *loadOp, Operation *storeOp,
573 const IRMapping &loopsIVsMap) {
574 if (auto transfWriteOp = dyn_cast<vector::TransferWriteOp>(storeOp);
575 transfWriteOp && isa<memref::LoadOp>(loadOp))
576 return loadMatchesVectorWrite(cast<memref::LoadOp>(loadOp), transfWriteOp,
577 loopsIVsMap);
578 if (auto vecStoreOp = dyn_cast<vector::StoreOp>(storeOp);
579 vecStoreOp && isa<memref::LoadOp>(loadOp))
580 return loadMatchesVectorStore(cast<memref::LoadOp>(loadOp), vecStoreOp,
581 loopsIVsMap);
582 return false;
583}
584
585/// Check that the parallel loops have no mixed access to the same buffers.
586/// Return `true` if the second parallel loop does not read or write the buffers
587/// written by the first loop using different indices.
589 ParallelOp firstPloop, ParallelOp secondPloop,
590 const IRMapping &firstToSecondPloopIndices,
592 // Map buffers to their store/write ops in the firstPloop
593 DenseMap<Value, SmallVector<Operation *>> bufferStoresInFirstPloop;
594 // Record all the memory buffers used in store/write ops found in firstPloop
595 llvm::SmallSetVector<Value, 4> buffersWrittenInFirstPloop;
596
597 auto collectStoreOpsInWalk = [&](Operation *op) {
598 auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(op);
599 // Ignore ops that don't write to memory
600 if (!memOpInterf || (!memOpInterf.hasEffect<MemoryEffects::Write>() &&
601 !memOpInterf.hasEffect<MemoryEffects::Free>()))
602 return WalkResult::advance();
603
604 // Only these memory-writing ops are supported for now:
605 // memref.store, vector.transfer_write, vector.store
606 Value storeOpBase = getStoreOpTargetBuffer(op);
607 if (!storeOpBase)
608 return WalkResult::interrupt();
609
610 // Expect the base operand to be a Memref
611 MemrefValue storeOpBaseMemref = dyn_cast<MemrefValue>(storeOpBase);
612 if (!storeOpBaseMemref)
613 return WalkResult::interrupt();
614 // Get the original memref buffer, skipping full view-like ops
615 Value buffer = memref::skipFullyAliasingOperations(storeOpBaseMemref);
616 bufferStoresInFirstPloop[buffer].push_back(op);
617 buffersWrittenInFirstPloop.insert(buffer);
618 return WalkResult::advance();
619 };
620
621 // Walk the first parallel loop to collect all store/write ops and their
622 // target buffers
623 if (firstPloop.getBody()->walk(collectStoreOpsInWalk).wasInterrupted())
624 return false;
625
626 // Check that this load/read op encountered while walking the second parallel
627 // loop does not have incompatible data dependencies with the store/write ops
628 // collected from the first parallel loop: the loops can be fused only if in
629 // the 2nd loop there are no loads/stores from/to the buffers written in the
630 // 1st loop, except when on the same exact memory location (same indices) as
631 // written in the 1st loop.
632 auto checkLoadInWalkHasNoIncompatibleDataDeps = [&](Operation *loadOp) {
633 auto memOpInterf = dyn_cast_if_present<MemoryEffectOpInterface>(loadOp);
634 // To be conservative, we should stop on ops that don't advertise their
635 // memory effects. However, many ops don't implement MemoryEffectOpInterface
636 // yet, so for now we just skip them.
637 // TODO: once more ops add MemoryEffectOpInterface, interrupt the walk here.
638 if (!memOpInterf &&
639 !loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>())
640 return WalkResult::advance();
641 // Ignore ops that don't read from memory, and wrapping ops that have nested
642 // memory effects (e.g. loops, conditionals) as they will be analyzed when
643 // visiting their nested ops.
644 if ((!memOpInterf &&
645 loadOp->hasTrait<mlir::OpTrait::HasRecursiveMemoryEffects>()) ||
646 (memOpInterf && !memOpInterf.hasEffect<MemoryEffects::Read>()))
647 return WalkResult::advance();
648 // Support only these memory-reading ops for now
649 if (!isa<memref::LoadOp, vector::TransferReadOp, vector::LoadOp>(loadOp) ||
650 !isa<MemrefValue>(loadOp->getOperand(0)))
651 return WalkResult::interrupt();
652
653 MemrefValue loadOpBase = cast<MemrefValue>(loadOp->getOperand(0));
654 MemrefValue loadedOrigBuf = memref::skipFullyAliasingOperations(loadOpBase);
655
656 for (Value storedMem : buffersWrittenInFirstPloop)
657 if ((storedMem != loadedOrigBuf) && mayAlias(storedMem, loadedOrigBuf) &&
658 !llvm::all_of(bufferStoresInFirstPloop[storedMem],
659 [&](Operation *storeOp) {
660 return canResolveAlias(loadOp, storeOp,
661 firstToSecondPloopIndices);
662 })) {
663 return WalkResult::interrupt();
664 }
665
666 auto writeOpsIt = bufferStoresInFirstPloop.find(loadedOrigBuf);
667 if (writeOpsIt == bufferStoresInFirstPloop.end())
668 return WalkResult::advance();
669 // Store/write ops to this buffer in the firstPloop
670 SmallVector<mlir::Operation *> &writeOps = writeOpsIt->second;
671
672 // If the first loop has no writes to this buffer, continue
673 if (writeOps.empty())
674 return WalkResult::advance();
675
676 Operation *writeOp = writeOps.front();
677
678 // In the first parallel loop, multiple writes to the same memref are
679 // allowed only on the same memory location
680 if (!llvm::all_of(writeOps, [&](Operation *otherWriteOp) {
681 return opsWriteSameMemLocation(writeOp, otherWriteOp);
682 })) {
683 return WalkResult::interrupt();
684 }
685
686 // Check that the load in secondPloop reads from the same memory location as
687 // written by the corresponding store in firstPloop
688 if (!loadsFromSameMemoryLocationWrittenBy(loadOp, writeOp,
689 firstToSecondPloopIndices, b)) {
690 return WalkResult::interrupt();
691 }
692
693 return WalkResult::advance();
694 };
695
696 // Walk the second parallel loop to check load/read ops against the stores
697 // collected from the first parallel loop.
698 return !secondPloop.getBody()
699 ->walk(checkLoadInWalkHasNoIncompatibleDataDeps)
700 .wasInterrupted();
701}
702
703/// Check that in each loop there are no read ops on the buffers written
704/// by the other loop, except when reading from the same exact memory location
705/// (same indices) as written in the other loop.
706static bool
707noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
708 const IRMapping &firstToSecondPloopIndices,
710 OpBuilder &b) {
712 firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias, b))
713 return false;
714
715 IRMapping secondToFirstPloopIndices;
716 secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
717 firstPloop.getBody()->getArguments());
719 secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias, b);
720}
721
722/// Check if fusion of the two parallel loops is legal:
723/// i.e. no nested parallel loops, equal iteration spaces,
724/// and no incompatible data dependencies between the loops.
725static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
726 const IRMapping &firstToSecondPloopIndices,
728 OpBuilder &b) {
729 return !hasNestedParallelOp(firstPloop) &&
730 !hasNestedParallelOp(secondPloop) &&
731 equalIterationSpaces(firstPloop, secondPloop) &&
732 noIncompatibleDataDependencies(firstPloop, secondPloop,
733 firstToSecondPloopIndices, mayAlias, b);
734}
735
736/// Prepend operations of firstPloop's body into secondPloop's body.
737/// Update secondPloop with new loop.
738static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
739 OpBuilder builder,
741 Block *block1 = firstPloop.getBody();
742 Block *block2 = secondPloop.getBody();
743 IRMapping firstToSecondPloopIndices;
744 firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
745
746 if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
747 mayAlias, builder))
748 return;
749
750 DominanceInfo dom;
751 // We are fusing first loop into second, make sure there are no users of the
752 // first loop results between loops.
753 for (Operation *user : firstPloop->getUsers())
754 if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
755 return;
756
757 ValueRange inits1 = firstPloop.getInitVals();
758 ValueRange inits2 = secondPloop.getInitVals();
759
760 SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
761 newInitVars.append(inits2.begin(), inits2.end());
762
763 IRRewriter b(builder);
764 b.setInsertionPoint(secondPloop);
765 auto newSecondPloop = ParallelOp::create(
766 b, secondPloop.getLoc(), secondPloop.getLowerBound(),
767 secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
768
769 Block *newBlock = newSecondPloop.getBody();
770 auto term1 = cast<ReduceOp>(block1->getTerminator());
771 auto term2 = cast<ReduceOp>(block2->getTerminator());
772
773 b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
774 newBlock->getArguments());
775 b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
776 newBlock->getArguments());
777
778 ValueRange results = newSecondPloop.getResults();
779 if (!results.empty()) {
780 b.setInsertionPointToEnd(newBlock);
781
782 ValueRange reduceArgs1 = term1.getOperands();
783 ValueRange reduceArgs2 = term2.getOperands();
784 SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
785 newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
786
787 auto newReduceOp = scf::ReduceOp::create(b, term2.getLoc(), newReduceArgs);
788
789 for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
790 term1.getReductions(), term2.getReductions()))) {
791 Block &oldRedBlock = reg.front();
792 Block &newRedBlock = newReduceOp.getReductions()[i].front();
793 b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
794 newRedBlock.getArguments());
795 }
796
797 firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
798 secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
799 }
800 term1->erase();
801 term2->erase();
802 firstPloop.erase();
803 secondPloop.erase();
804 secondPloop = newSecondPloop;
805}
806
808 Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
809 OpBuilder b(region);
810 // Consider every single block and attempt to fuse adjacent loops.
812 for (auto &block : region) {
813 ploopChains.clear();
814 ploopChains.push_back({});
815
816 // Not using `walk()` to traverse only top-level parallel loops and also
817 // make sure that there are no side-effecting ops between the parallel
818 // loops.
819 bool noSideEffects = true;
820 for (auto &op : block) {
821 if (auto ploop = dyn_cast<ParallelOp>(op)) {
822 if (noSideEffects) {
823 ploopChains.back().push_back(ploop);
824 } else {
825 ploopChains.push_back({ploop});
826 noSideEffects = true;
827 }
828 continue;
829 }
830 // TODO: Handle region side effects properly.
831 noSideEffects &= isMemoryEffectFree(&op) && op.getNumRegions() == 0;
832 }
833 for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
834 for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
835 fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
836 }
837 }
838}
839
840namespace {
841struct ParallelLoopFusion
842 : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> {
843 void runOnOperation() override {
844 auto &aa = getAnalysis<AliasAnalysis>();
845
846 auto mayAlias = [&](Value val1, Value val2) -> bool {
847 // If the memref is defined in one of the parallel loops body, careful
848 // alias analysis is needed.
849 // TODO: check if this is still needed as a separate check.
850 auto val1Def = val1.getDefiningOp();
851 auto val2Def = val2.getDefiningOp();
852 auto val1Loop =
853 val1Def ? val1Def->getParentOfType<ParallelOp>() : nullptr;
854 auto val2Loop =
855 val2Def ? val2Def->getParentOfType<ParallelOp>() : nullptr;
856 if (val1Loop != val2Loop)
857 return true;
858
859 return !aa.alias(val1, val2).isNo();
860 };
861
862 getOperation()->walk([&](Operation *child) {
863 for (Region &region : child->getRegions())
865 });
866 }
867};
868} // namespace
869
870std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
871 return std::make_unique<ParallelLoopFusion>();
872}
return success()
static bool mayAlias(Value first, Value second)
Returns true if two values may be referencing aliasing memory.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or of the inner_dims_pos (case OuterOrInnerPerm::Inner).
auto load
static bool canResolveAlias(Operation *loadOp, Operation *storeOp, const IRMapping &loopsIVsMap)
To be called when mayAlias(val1, val2) is true.
static bool equalIterationSpaces(ParallelOp firstPloop, ParallelOp secondPloop)
Verify equal iteration spaces.
static bool isLoadOnWrittenVector(memref::LoadOp loadOp, Value writeBase, ValueRange writeIndices, VectorType vecTy, ArrayRef< int64_t > vectorDimForWriteDim, const IRMapping &ivsMap)
Recognize scalar memref.load of an element produced by a vector write (vector.transfer_write or vector.store).
static bool loadMatchesVectorWrite(memref::LoadOp loadOp, vector::TransferWriteOp writeOp, const IRMapping &ivsMap)
Recognize scalar memref.load of an element produced by a vector.transfer_write.
static std::optional< int64_t > getAddConstant(Value expr, Value base, const IRMapping &loopsIVsMap)
If the expr value is the result of an integer addition of base and a constant, return the constant.
static bool opsAccessSameIndices(OpTy1 op1, OpTy2 op2, const IRMapping &loopsIVsMap, OpBuilder &b)
Check if both memory read/write operations access the same indices (considering also the mapping of induction variables between the two loops).
static Value getStoreOpTargetBuffer(Operation *op)
static bool haveNoDataDependenciesExceptSameIndex(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias, OpBuilder &b)
Check that the parallel loops have no mixed access to the same buffers.
static Value getBaseMemref(Operation *op)
Return the base memref value used by the given memory op.
static bool loadsFromSameMemoryLocationWrittenBy(Operation *loadOp, Operation *storeOp, const IRMapping &firstToSecondPloopIVsMap, OpBuilder &b)
Check if the loadOp reads from the same memory location (same buffer, same indices and same properties) written by the storeOp.
static bool loadIndexWithinWriteRange(Value loadIndex, OpFoldResult offset, Value writeIndex, int64_t extent, const IRMapping &loopsIVsMap)
static bool opsWriteSameMemLocation(Operation *op1, Operation *op2)
Check if both operations are the same type of memory write op and write to the same memory location (...
static bool noIncompatibleDataDependencies(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias, OpBuilder &b)
Check that in each loop there are no read ops on the buffers written by the other loop,...
static bool valsAreEquivalent(Value val1, Value val2, const IRMapping &loopsIVsMap)
Check if val1 (from the first parallel loop) and val2 (from the second) are equivalent,...
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop, const IRMapping &firstToSecondPloopIndices, llvm::function_ref< bool(Value, Value)> mayAlias, OpBuilder &b)
Check if fusion of the two parallel loops is legal: i.e.
static bool opsAccessSameIndicesViaRankReducingSubview(OpTy1 op1, OpTy2 op2, const IRMapping &firstToSecondPloopIVsMap, OpBuilder &b)
Check if both operations access the same positions of the same buffer, but one of the two does it through a rank-reducing subview.
static bool loadMatchesVectorStore(memref::LoadOp loadOp, vector::StoreOp storeOp, const IRMapping &ivsMap)
Recognize scalar memref.load of an element produced by a vector.store.
static bool hasNestedParallelOp(ParallelOp ploop)
Verify there are no nested ParallelOps.
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, OpBuilder builder, llvm::function_ref< bool(Value, Value)> mayAlias)
Prepend operations of firstPloop's body into secondPloop's body.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
iterator begin()
Definition Block.h:153
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This trait indicates that the memory effects of an operation includes the effects of operations neste...
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
user_range getUsers()
Returns a range of all users.
Definition Operation.h:902
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
MemrefValue skipFullyAliasingOperations(MemrefValue source)
Walk up the source chain until an operation that changes/defines the view of memory is found (i....
bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b)
Checks if two (memref) values are the same or statically known to alias the same region of memory.
void naivelyFuseParallelOps(Region &region, llvm::function_ref< bool(Value, Value)> mayAlias)
Fuses all adjacent scf.parallel operations with identical bounds and step into one scf....
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
llvm::SmallVector< std::tuple< int64_t, int64_t, int64_t > > getConstLoopBounds(mlir::LoopLikeOpInterface loopOp)
Get constant loop bounds and steps for each of the induction variables of the given loop operation,...
Definition Utils.cpp:1576
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
TypedValue< BaseMemRefType > MemrefValue
A value with a memref type.
Definition MemRefUtils.h:26
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
std::unique_ptr< Pass > createParallelLoopFusionPass()
Creates a loop fusion pass which fuses parallel loops.
The following effect indicates that the operation frees some resource that has been allocated.
The following effect indicates that the operation reads from some resource.
The following effect indicates that the operation writes to some resource.
static bool isEquivalentTo(Operation *lhs, Operation *rhs, function_ref< LogicalResult(Value, Value)> checkEquivalent, function_ref< void(Value, Value)> markEquivalent=nullptr, Flags flags=Flags::None, function_ref< LogicalResult(ValueRange, ValueRange)> checkCommutativeEquivalent=nullptr)
Compare two operations (including their regions) and return if they are equivalent.